{
  "cells": [
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "Vq31CdSRpgkI"
      },
      "source": [
        "# Customizing embeddings\n",
        "\n",
        "This notebook demonstrates one way to customize OpenAI embeddings to a particular task.\n",
        "\n",
        "The input is training data in the form of [text_1, text_2, label], where label is +1 if the pair is similar and -1 if the pair is dissimilar.\n",
        "\n",
        "The output is a matrix that you can multiply your embeddings by. The product of this multiplication is a 'custom embedding' that will better emphasize the aspects of the text relevant to your use case. In binary classification use cases, we've seen error rates drop by as much as 50%.\n",
        "\n",
        "In the following example, I use up to 1,000 sentence pairs picked from the SNLI corpus. Each pair of sentences is logically entailed (i.e., the first sentence implies the second). These pairs are our positives (label = 1). We generate synthetic negatives by combining sentences from different pairs, which are presumed not to be logically entailed (label = -1).\n",
        "\n",
        "For a clustering use case, you can generate positives by creating pairs from texts in the same cluster and generate negatives by creating pairs from texts in different clusters.\n",
        "\n",
        "With other datasets, we have seen decent improvement with as few as ~100 training examples. Of course, performance will be better with more examples."
      ]
    },
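    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Here is a minimal sketch of how the output matrix is applied, assuming embeddings are stored as the rows of a NumPy array. Multiplying each embedding by the matrix gives its custom embedding, and similarities are then computed between custom embeddings. The random `embeddings` and `matrix` below are placeholders, not trained values. The input dimension of 1536 matches `text-embedding-3-small`; the output dimension of 2048 is an arbitrary assumption.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "rng = np.random.default_rng(0)\n",
        "\n",
        "# placeholder data: 3 embeddings of dimension 1536, scaled to unit length like OpenAI embeddings\n",
        "embeddings = rng.normal(size=(3, 1536))\n",
        "embeddings /= np.linalg.norm(embeddings, axis=1, keepdims=True)\n",
        "\n",
        "# placeholder for a trained matrix; the output dimension (2048) is an assumption\n",
        "matrix = rng.normal(size=(1536, 2048))\n",
        "\n",
        "# each row of custom_embeddings is one original embedding multiplied by the matrix\n",
        "custom_embeddings = embeddings @ matrix\n",
        "\n",
        "\n",
        "def cosine(a: np.ndarray, b: np.ndarray) -> float:\n",
        "    \"\"\"Cosine similarity between two vectors.\"\"\"\n",
        "    return float(np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)))\n",
        "\n",
        "\n",
        "print(custom_embeddings.shape)  # (3, 2048)\n",
        "print(cosine(custom_embeddings[0], custom_embeddings[1]))\n",
        "```\n",
        "\n",
        "Because the transformation is a fixed matrix multiplication, embeddings you have already computed can be customized locally without calling the embeddings API again."
      ]
    },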
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "arB38jFwpgkK"
      },
      "source": [
        "## 0. Imports"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "ifvM7g4apgkK"
      },
      "outputs": [],
      "source": [
        "# imports\n",
        "from typing import List, Tuple  # for type hints\n",
        "\n",
        "import numpy as np  # for manipulating arrays\n",
        "import pandas as pd  # for manipulating data in dataframes\n",
        "import pickle  # for saving the embeddings cache\n",
        "import plotly.express as px  # for plots\n",
        "import random  # for generating run IDs\n",
        "from sklearn.model_selection import train_test_split  # for splitting train & test data\n",
        "import torch  # for matrix optimization\n",
        "\n",
        "from utils.embeddings_utils import get_embedding, cosine_similarity  # for embeddings\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "DtBbryAapgkL"
      },
      "source": [
        "## 1. Inputs\n",
        "\n",
        "Most of the inputs are defined here. The key things to change are where to load your dataset from, where to save a cache of embeddings to, and which embedding engine you want to use.\n",
        "\n",
        "Depending on how your data is formatted, you'll want to rewrite the `process_input_data` function."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "UzxcWRCkpgkM"
      },
      "outputs": [],
      "source": [
        "# input parameters\n",
        "embedding_cache_path = \"data/snli_embedding_cache.pkl\"  # embeddings will be saved/loaded here\n",
        "default_embedding_engine = \"text-embedding-3-small\"\n",
        "num_pairs_to_embed = 1000  # 1000 is arbitrary\n",
        "local_dataset_path = \"data/snli_1.0_train_2k.csv\"  # download from: https://nlp.stanford.edu/projects/snli/\n",
        "\n",
        "\n",
        "def process_input_data(df: pd.DataFrame) -> pd.DataFrame:\n",
        "    # you can customize this to preprocess your own dataset\n",
        "    # output should be a dataframe with 3 columns: text_1, text_2, label (1 for similar, -1 for dissimilar)\n",
        "    df[\"label\"] = df[\"gold_label\"]\n",
        "    df = df[df[\"label\"].isin([\"entailment\"])].copy()  # copy so the next assignment doesn't raise SettingWithCopyWarning\n",
        "    df[\"label\"] = df[\"label\"].apply(lambda x: {\"entailment\": 1, \"contradiction\": -1}[x])\n",
        "    df = df.rename(columns={\"sentence1\": \"text_1\", \"sentence2\": \"text_2\"})\n",
        "    df = df[[\"text_1\", \"text_2\", \"label\"]]\n",
        "    df = df.head(num_pairs_to_embed)\n",
        "    return df\n"
      ]
    },
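    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For example, if your dataset already contains both similar and dissimilar pairs, a rewritten preprocessing function might look like the minimal sketch below. The column names `question_a`, `question_b`, and `is_duplicate` (1 for similar, 0 for dissimilar) are hypothetical, so substitute your own. The function is named `process_labeled_pairs` so that it does not overwrite the SNLI version above.\n",
        "\n",
        "```python\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "def process_labeled_pairs(df: pd.DataFrame, max_pairs: int = 1000) -> pd.DataFrame:\n",
        "    \"\"\"Map a hypothetical (question_a, question_b, is_duplicate) table to text_1, text_2, label.\"\"\"\n",
        "    out = df.rename(columns={\"question_a\": \"text_1\", \"question_b\": \"text_2\"})\n",
        "    # convert 1/0 labels to the 1/-1 convention used in this notebook\n",
        "    out[\"label\"] = out[\"is_duplicate\"].map({1: 1, 0: -1})\n",
        "    # drop rows with missing text or an unrecognized label\n",
        "    out = out.dropna(subset=[\"text_1\", \"text_2\", \"label\"])\n",
        "    return out[[\"text_1\", \"text_2\", \"label\"]].head(max_pairs)\n",
        "\n",
        "\n",
        "raw = pd.DataFrame(\n",
        "    {\n",
        "        \"question_a\": [\"How do I bake bread?\", \"What is Python?\"],\n",
        "        \"question_b\": [\"What is a good bread recipe?\", \"How tall is Everest?\"],\n",
        "        \"is_duplicate\": [1, 0],\n",
        "    }\n",
        ")\n",
        "print(process_labeled_pairs(raw))\n",
        "```"
      ]
    },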
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "aBbH71hEpgkM"
      },
      "source": [
        "## 2. Load and process input data"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "id": "kAKLjYG6pgkN",
        "outputId": "dc178688-e97d-4ad0-b26c-dff67b858966"
      },
      "outputs": [
        {
          "data": {
            "text/html": [
              "<div>\n",
              "<style scoped>\n",
              "    .dataframe tbody tr th:only-of-type {\n",
              "        vertical-align: middle;\n",
              "    }\n",
              "\n",
              "    .dataframe tbody tr th {\n",
              "        vertical-align: top;\n",
              "    }\n",
              "\n",
              "    .dataframe thead th {\n",
              "        text-align: right;\n",
              "    }\n",
              "</style>\n",
              "<table border=\"1\" class=\"dataframe\">\n",
              "  <thead>\n",
              "    <tr style=\"text-align: right;\">\n",
              "      <th></th>\n",
              "      <th>text_1</th>\n",
              "      <th>text_2</th>\n",
              "      <th>label</th>\n",
              "    </tr>\n",
              "  </thead>\n",
              "  <tbody>\n",
              "    <tr>\n",
              "      <th>2</th>\n",
              "      <td>A person on a horse jumps over a broken down a...</td>\n",
              "      <td>A person is outdoors, on a horse.</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>4</th>\n",
              "      <td>Children smiling and waving at camera</td>\n",
              "      <td>There are children present</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>7</th>\n",
              "      <td>A boy is jumping on skateboard in the middle o...</td>\n",
              "      <td>The boy does a skateboarding trick.</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>14</th>\n",
              "      <td>Two blond women are hugging one another.</td>\n",
              "      <td>There are women showing affection.</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "    <tr>\n",
              "      <th>17</th>\n",
              "      <td>A few people in a restaurant setting, one of t...</td>\n",
              "      <td>The diners are at a restaurant.</td>\n",
              "      <td>1</td>\n",
              "    </tr>\n",
              "  </tbody>\n",
              "</table>\n",
              "</div>"
            ],
            "text/plain": [
              "                                               text_1  \\\n",
              "2   A person on a horse jumps over a broken down a...   \n",
              "4               Children smiling and waving at camera   \n",
              "7   A boy is jumping on skateboard in the middle o...   \n",
              "14           Two blond women are hugging one another.   \n",
              "17  A few people in a restaurant setting, one of t...   \n",
              "\n",
              "                                 text_2  label  \n",
              "2     A person is outdoors, on a horse.      1  \n",
              "4            There are children present      1  \n",
              "7   The boy does a skateboarding trick.      1  \n",
              "14   There are women showing affection.      1  \n",
              "17      The diners are at a restaurant.      1  "
            ]
          },
          "execution_count": 3,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "# load data\n",
        "df = pd.read_csv(local_dataset_path)\n",
        "\n",
        "# process input data\n",
        "df = process_input_data(df)  # this demonstrates training data containing only positives\n",
        "\n",
        "# view data\n",
        "df.head()\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "z2F1cCoYpgkO"
      },
      "source": [
        "## 3. Split data into training and test sets\n",
        "\n",
        "Note that it's important to split data into training and test sets *before* generating synthetic negatives or positives. You don't want any text strings in the training data to show up in the test data. If there's contamination, the test metrics will look better than they will actually be in production."
      ]
    },
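    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "If the same text can appear in more than one row (for example, a premise reused across several hypotheses), a row-level split can still place that text in both sets. Below is a minimal sketch that checks for shared strings, and shows one way to avoid the problem: splitting on groups of rows that share a text, using scikit-learn's `GroupShuffleSplit`. Using `text_1` as the group key is an assumption about where repeats occur in your data. The helper `shared_texts` and the toy rows are hypothetical.\n",
        "\n",
        "```python\n",
        "import pandas as pd\n",
        "from sklearn.model_selection import GroupShuffleSplit\n",
        "\n",
        "\n",
        "def shared_texts(train: pd.DataFrame, test: pd.DataFrame) -> set:\n",
        "    \"\"\"Return strings that appear in both splits, in either text column.\"\"\"\n",
        "    train_texts = set(train[\"text_1\"]) | set(train[\"text_2\"])\n",
        "    test_texts = set(test[\"text_1\"]) | set(test[\"text_2\"])\n",
        "    return train_texts & test_texts\n",
        "\n",
        "\n",
        "# made-up pairs; the first two rows share the same text_1\n",
        "toy_pairs = pd.DataFrame(\n",
        "    {\n",
        "        \"text_1\": [\"A dog runs.\", \"A dog runs.\", \"Kids play.\", \"A man cooks.\"],\n",
        "        \"text_2\": [\"An animal moves.\", \"A dog is outside.\", \"Children are present.\", \"Someone makes food.\"],\n",
        "        \"label\": [1, 1, 1, 1],\n",
        "    }\n",
        ")\n",
        "\n",
        "# split by groups of rows that share text_1, so no text_1 value lands in both sets\n",
        "splitter = GroupShuffleSplit(n_splits=1, test_size=0.5, random_state=123)\n",
        "train_idx, test_idx = next(splitter.split(toy_pairs, groups=toy_pairs[\"text_1\"]))\n",
        "toy_train, toy_test = toy_pairs.iloc[train_idx], toy_pairs.iloc[test_idx]\n",
        "print(shared_texts(toy_train, toy_test))  # set()\n",
        "```\n",
        "\n",
        "With this notebook's data, `shared_texts(train_df, test_df)` would list any strings worth investigating before you trust the test metrics."
      ]
    },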
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "id": "50QmnH2qpgkO",
        "outputId": "6144029b-eb29-439e-9990-7aeb28168e56"
      },
      "outputs": [],
      "source": [
        "# split data into train and test sets\n",
        "test_fraction = 0.5  # 0.5 is fairly arbitrary\n",
        "random_seed = 123  # random seed is arbitrary, but is helpful in reproducibility\n",
        "train_df, test_df = train_test_split(\n",
        "    df, test_size=test_fraction, stratify=df[\"label\"], random_state=random_seed\n",
        ")\n",
        "train_df.loc[:, \"dataset\"] = \"train\"\n",
        "test_df.loc[:, \"dataset\"] = \"test\"\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "MzAFkA2opgkP"
      },
      "source": [
        "## 4. Generate synthetic negatives\n",
        "\n",
        "This is another piece of the code that you will need to modify to match your use case.\n",
        "\n",
        "If you have data with positives and negatives, you can skip this section.\n",
        "\n",
        "If you have data with only positives, you can largely keep this code as is; it generates negatives only.\n",
        "\n",
        "If you have multiclass data, you will want to generate both positives and negatives. The positives can be pairs of text that share labels, and the negatives can be pairs of text that do not share labels.\n",
        "\n",
        "The final output should be a dataframe with text pairs, where each pair is labeled -1 or 1."
      ]
    },
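    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "For multiclass data, here is a minimal sketch of generating both positives and negatives. It assumes a dataframe with hypothetical columns `text` and `class_label`: every pair of texts that shares a class becomes a positive (1), and every other pair becomes a negative (-1). The toy rows are made up. The number of pairs grows quadratically with the number of texts, so on a real dataset you would likely sample from these pairs, as this notebook does for negatives.\n",
        "\n",
        "```python\n",
        "from itertools import combinations\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "def pairs_from_multiclass(df: pd.DataFrame) -> pd.DataFrame:\n",
        "    \"\"\"Label every pair of texts 1 if they share a class, otherwise -1.\"\"\"\n",
        "    records = list(df[[\"text\", \"class_label\"]].itertuples(index=False))\n",
        "    rows = [\n",
        "        (text_a, text_b, 1 if class_a == class_b else -1)\n",
        "        for (text_a, class_a), (text_b, class_b) in combinations(records, 2)\n",
        "    ]\n",
        "    return pd.DataFrame(rows, columns=[\"text_1\", \"text_2\", \"label\"])\n",
        "\n",
        "\n",
        "toy_classes = pd.DataFrame(\n",
        "    {\n",
        "        \"text\": [\"Refund my order\", \"I want my money back\", \"Reset my password\", \"I can't log in\"],\n",
        "        \"class_label\": [\"billing\", \"billing\", \"account\", \"account\"],\n",
        "    }\n",
        ")\n",
        "pairs = pairs_from_multiclass(toy_classes)\n",
        "print(pairs[\"label\"].value_counts())  # 2 positives, 4 negatives\n",
        "```"
      ]
    },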
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "id": "rUYd9V0zpgkP"
      },
      "outputs": [],
      "source": [
        "# generate negatives\n",
        "def dataframe_of_negatives(dataframe_of_positives: pd.DataFrame) -> pd.DataFrame:\n",
        "    \"\"\"Return dataframe of negative pairs made by combining elements of positive pairs.\"\"\"\n",
        "    texts = set(dataframe_of_positives[\"text_1\"].values) | set(\n",
        "        dataframe_of_positives[\"text_2\"].values\n",
        "    )\n",
        "    all_pairs = {(t1, t2) for t1 in texts for t2 in texts if t1 < t2}\n",
        "    # sort each positive pair to match the (t1 < t2) ordering used in all_pairs\n",
        "    positive_pairs = set(\n",
        "        tuple(sorted(text_pair))\n",
        "        for text_pair in dataframe_of_positives[[\"text_1\", \"text_2\"]].values\n",
        "    )\n",
        "    negative_pairs = all_pairs - positive_pairs\n",
        "    df_of_negatives = pd.DataFrame(list(negative_pairs), columns=[\"text_1\", \"text_2\"])\n",
        "    df_of_negatives[\"label\"] = -1\n",
        "    return df_of_negatives\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "id": "Rkh8-J89pgkP"
      },
      "outputs": [],
      "source": [
        "negatives_per_positive = (\n",
        "    1  # it will work at higher values too, but more data will be slower\n",
        ")\n",
        "# generate negatives for training dataset\n",
        "train_df_negatives = dataframe_of_negatives(train_df)\n",
        "train_df_negatives[\"dataset\"] = \"train\"\n",
        "# generate negatives for test dataset\n",
        "test_df_negatives = dataframe_of_negatives(test_df)\n",
        "test_df_negatives[\"dataset\"] = \"test\"\n",
        "# sample negatives and combine with positives\n",
        "train_df = pd.concat(\n",
        "    [\n",
        "        train_df,\n",
        "        train_df_negatives.sample(\n",
        "            n=len(train_df) * negatives_per_positive, random_state=random_seed\n",
        "        ),\n",
        "    ]\n",
        ")\n",
        "test_df = pd.concat(\n",
        "    [\n",
        "        test_df,\n",
        "        test_df_negatives.sample(\n",
        "            n=len(test_df) * negatives_per_positive, random_state=random_seed\n",
        "        ),\n",
        "    ]\n",
        ")\n",
        "\n",
        "df = pd.concat([train_df, test_df])\n"
      ]
    },
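    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "A quick sanity check, sketched here with made-up pairs: confirm that no synthetic negative is actually a known positive, treating `(a, b)` and `(b, a)` as the same pair. The helper name `order_insensitive_overlap` is hypothetical. With this notebook's data, you could call it on `train_df[train_df.label == 1]` and `train_df[train_df.label == -1]`.\n",
        "\n",
        "```python\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "def order_insensitive_overlap(positives: pd.DataFrame, negatives: pd.DataFrame) -> set:\n",
        "    \"\"\"Return pairs present in both dataframes, ignoring the order of text_1 and text_2.\"\"\"\n",
        "\n",
        "    def as_pair_set(frame: pd.DataFrame) -> set:\n",
        "        # a frozenset of the two texts is the same regardless of their order\n",
        "        return {frozenset(pair) for pair in frame[[\"text_1\", \"text_2\"]].itertuples(index=False)}\n",
        "\n",
        "    return as_pair_set(positives) & as_pair_set(negatives)\n",
        "\n",
        "\n",
        "positives = pd.DataFrame({\"text_1\": [\"b\", \"c\"], \"text_2\": [\"a\", \"d\"]})\n",
        "negatives = pd.DataFrame({\"text_1\": [\"a\", \"a\"], \"text_2\": [\"b\", \"d\"]})\n",
        "print(order_insensitive_overlap(positives, negatives))  # the reversed pair (a, b) is caught\n",
        "```"
      ]
    },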
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "8MVSLMSrpgkQ"
      },
      "source": [
        "## 5. Calculate embeddings and cosine similarities\n",
        "\n",
        "Here, I create a cache to save the embeddings. This is handy because you don't have to pay to compute the same embeddings again if you rerun the code."
      ]
    },
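    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the caching pattern concrete, here is a minimal sketch that uses a stand-in `fake_embed` function instead of the API. The function, `embed_with_cache`, and the temporary file path are hypothetical. The second request for the same `(text, engine)` key is served from the cache, so the stand-in is called only once, and the pickled file can be reloaded in a later session.\n",
        "\n",
        "```python\n",
        "import os\n",
        "import pickle\n",
        "import tempfile\n",
        "\n",
        "calls = []\n",
        "\n",
        "\n",
        "def fake_embed(text: str, engine: str) -> list:\n",
        "    \"\"\"Stand-in for an embeddings API call; records each call it receives.\"\"\"\n",
        "    calls.append((text, engine))\n",
        "    return [float(len(text)), 0.0]\n",
        "\n",
        "\n",
        "cache_path = os.path.join(tempfile.mkdtemp(), \"cache.pkl\")\n",
        "cache: dict = {}\n",
        "\n",
        "\n",
        "def embed_with_cache(text: str, engine: str = \"text-embedding-3-small\") -> list:\n",
        "    key = (text, engine)\n",
        "    if key not in cache:\n",
        "        cache[key] = fake_embed(text, engine)\n",
        "        # persist the whole cache so a later session can reload it with pickle.load\n",
        "        with open(cache_path, \"wb\") as f:\n",
        "            pickle.dump(cache, f)\n",
        "    return cache[key]\n",
        "\n",
        "\n",
        "embed_with_cache(\"A dog runs.\")\n",
        "embed_with_cache(\"A dog runs.\")\n",
        "print(len(calls))  # 1\n",
        "\n",
        "with open(cache_path, \"rb\") as f:\n",
        "    print(pickle.load(f) == cache)  # True\n",
        "```\n",
        "\n",
        "Like the notebook's version, this rewrites the whole file on every cache miss. For large datasets, saving once after a batch of new embeddings is a reasonable alternative."
      ]
    },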
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "R6tWgS_ApgkQ"
      },
      "outputs": [],
      "source": [
        "# establish a cache of embeddings to avoid recomputing\n",
        "# cache is a dict of tuples (text, engine) -> embedding\n",
        "try:\n",
        "    with open(embedding_cache_path, \"rb\") as f:\n",
        "        embedding_cache = pickle.load(f)\n",
        "except FileNotFoundError:\n",
        "    precomputed_embedding_cache_path = \"https://cdn.openai.com/API/examples/data/snli_embedding_cache.pkl\"\n",
        "    embedding_cache = pd.read_pickle(precomputed_embedding_cache_path)\n",
        "\n",
        "\n",
        "# return the cached embedding if present; otherwise call the API, store the result, and save the cache to disk\n",
        "def get_embedding_with_cache(\n",
        "    text: str,\n",
        "    engine: str = default_embedding_engine,\n",
        "    embedding_cache: dict = embedding_cache,\n",
        "    embedding_cache_path: str = embedding_cache_path,\n",
        ") -> list:\n",
        "    if (text, engine) not in embedding_cache.keys():\n",
        "        # if not in cache, call API to get embedding\n",
        "        embedding_cache[(text, engine)] = get_embedding(text, engine)\n",
        "        # save embeddings cache to disk after each update\n",
        "        with open(embedding_cache_path, \"wb\") as embedding_cache_file:\n",
        "            pickle.dump(embedding_cache, embedding_cache_file)\n",
        "    return embedding_cache[(text, engine)]\n",
        "\n",
        "\n",
        "# create column of embeddings\n",
        "for column in [\"text_1\", \"text_2\"]:\n",
        "    df[f\"{column}_embedding\"] = df[column].apply(get_embedding_with_cache)\n",
        "\n",
        "# create column of cosine similarity between embeddings\n",
        "df[\"cosine_similarity\"] = df.apply(\n",
        "    lambda row: cosine_similarity(row[\"text_1_embedding\"], row[\"text_2_embedding\"]),\n",
        "    axis=1,\n",
        ")\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "4pwn608LpgkQ"
      },
      "source": [
        "## 6. Plot distribution of cosine similarity\n",
        "\n",
        "Here we measure similarity of text using cosine similarity. In our experience, most distance functions (L1, L2, cosine similarity) work about the same. Note that our embeddings are already normalized to length 1, so cosine similarity is equivalent to the dot product.\n",
        "\n",
        "The graphs show how much overlap there is between the distributions of cosine similarities for similar and dissimilar pairs. A high amount of overlap means that some dissimilar pairs have greater cosine similarity than some similar pairs.\n",
        "\n",
        "The accuracy I compute is the accuracy of a simple rule that predicts 'similar (1)' if the cosine similarity is above some threshold X and otherwise predicts 'dissimilar (-1)'."
      ]
    },
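    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The threshold rule can be sketched as follows, assuming lists of cosine similarities and labels in the 1/-1 convention used here. This sketch tries each observed similarity as a candidate threshold and keeps the most accurate one. The function name `best_threshold_accuracy` is hypothetical, and this is not necessarily how the reported accuracy is computed in this notebook. The last lines check, on random vectors, that cosine similarity equals the dot product once vectors are normalized to length 1.\n",
        "\n",
        "```python\n",
        "import numpy as np\n",
        "\n",
        "\n",
        "def best_threshold_accuracy(similarities, labels) -> tuple:\n",
        "    \"\"\"Return (threshold, accuracy) for the rule: predict 1 if similarity > threshold, else -1.\"\"\"\n",
        "    sims = np.asarray(similarities, dtype=float)\n",
        "    labs = np.asarray(labels)\n",
        "    best_threshold, best_accuracy = None, 0.0\n",
        "    # a threshold just below the minimum predicts all 1s; each observed value covers every other split\n",
        "    for threshold in np.concatenate(([sims.min() - 1e-9], sims)):\n",
        "        predictions = np.where(sims > threshold, 1, -1)\n",
        "        accuracy = float(np.mean(predictions == labs))\n",
        "        if accuracy > best_accuracy:\n",
        "            best_threshold, best_accuracy = float(threshold), accuracy\n",
        "    return best_threshold, best_accuracy\n",
        "\n",
        "\n",
        "print(best_threshold_accuracy([0.9, 0.85, 0.8, 0.7], [1, 1, -1, -1]))  # (0.8, 1.0)\n",
        "\n",
        "# for unit-length vectors, cosine similarity equals the dot product\n",
        "rng = np.random.default_rng(0)\n",
        "a, b = rng.normal(size=8), rng.normal(size=8)\n",
        "cosine_ab = np.dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b))\n",
        "a_unit, b_unit = a / np.linalg.norm(a), b / np.linalg.norm(b)\n",
        "print(np.isclose(cosine_ab, np.dot(a_unit, b_unit)))  # True\n",
        "```"
      ]
    },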
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "id": "SoeDF8vqpgkQ",
        "outputId": "17db817e-1702-4089-c4e8-8ca32d294930"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.plotly.v1+json": {
              "config": {
                "plotlyServerURL": "https://plot.ly"
              },
              "data": [
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "1",
                  "marker": {
                    "color": "#636efa",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "1",
                  "offsetgroup": "1",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "histogram",
                  "x": [
                    0.9267355919345323,
                    0.8959824209230313,
                    0.9119725922265434,
                    0.854066984766886,
                    0.892887342475088,
                    0.9197115510183956,
                    0.8645429616364463,
                    0.8314164166177409,
                    0.7243313891695883,
                    0.8819971504547897,
                    0.7956215051893748,
                    0.7959481856578454,
                    0.8682525487547492,
                    0.8973704561532523,
                    0.8648042605746787,
                    0.9236698981903941,
                    0.9834804742323746,
                    0.8152447410515193,
                    0.8251720025778885,
                    0.8138195587168373,
                    0.8041885875094662,
                    0.9329690881953521,
                    0.9560346908585096,
                    0.9727875559800526,
                    0.8739787488907723,
                    0.8208200937845748,
                    0.7246913159387628,
                    0.9324916305400826,
                    0.8285737172791849,
                    0.8797008558439083,
                    0.820333215489803,
                    0.9370111008629193,
                    0.8983827475968527,
                    0.8312111255338568,
                    0.8164052516323846,
                    0.8908148647559723,
                    0.7466264012908705,
                    0.749651964301876,
                    0.87375582683117,
                    0.7849398161817545,
                    0.8309506403579568,
                    0.9307212180139316,
                    0.8281747408531835,
                    0.9529528469109777,
                    0.7828662080109449,
                    0.887100957550481,
                    0.9000278356226864,
                    0.8805448739521785,
                    0.930239131235984,
                    0.8801954897699975,
                    0.8529206900756547,
                    0.9467797362617967,
                    0.950367690540818,
                    0.7030531843080301,
                    0.8643992401506337,
                    0.8536886651933216,
                    0.9619331104636477,
                    0.9798279220663803,
                    0.8545739729953584,
                    0.8957115039259408,
                    0.8241137169318955,
                    0.8234984837204449,
                    0.8936706246811023,
                    0.8987178417246647,
                    0.9081806514389928,
                    0.9208852073112466,
                    0.8961858080980447,
                    0.8831329491347061,
                    0.9282623092166472,
                    0.8990849228018752,
                    0.8284548395672304,
                    0.8202091311790463,
                    0.8647762700375631,
                    0.840136958242196,
                    0.9887387562448309,
                    0.833342655983456,
                    0.828533112298018,
                    0.9117757392019592,
                    0.8706628948983062,
                    0.9279786042901599,
                    0.7389559895894061,
                    0.8433932035736679,
                    0.9240307526722153,
                    0.950769936374633,
                    0.8586024431762636,
                    0.8685107120156628,
                    0.875535036209879,
                    0.9894909159640456,
                    0.8279650063634755,
                    0.9108703739124493,
                    0.9090161898700972,
                    0.8603952576151226,
                    0.7791958170479588,
                    0.8800175078297319,
                    0.8442387843179446,
                    0.7672266106337577,
                    0.9379753906877134,
                    0.8637536227406404,
                    0.9190295684769377,
                    0.813748744480001,
                    0.9134886375270357,
                    0.8043760086358624,
                    0.8792300492020616,
                    0.8716796296133559,
                    0.8669146731224541,
                    0.7736224659067634,
                    0.9437235456545872,
                    0.905686329875817,
                    0.9534823418409002,
                    0.9150626356433706,
                    0.9409873570135594,
                    0.8111212499450311,
                    0.9171209901271343,
                    0.9126582215550586,
                    0.8337042981493767,
                    0.7317859043960847,
                    0.8444929460069593,
                    0.8561920422842992,
                    0.7765276739806221,
                    0.8526116779814181,
                    0.9178549171995068,
                    0.9238337665547537,
                    0.7218806795387105,
                    0.8180162420894919,
                    0.9687438998025378,
                    0.8354559162714565,
                    0.9146160667909478,
                    0.8082103463995212,
                    0.9563444955040953,
                    0.9066029892990775,
                    0.8485102485675593,
                    0.8154210965438722,
                    0.8860338155558548,
                    0.9280705420457523,
                    0.9835182434488767,
                    0.9653797793986199,
                    0.7815047672209323,
                    0.7156150747063292,
                    0.9256075945804699,
                    0.8135073898727201,
                    0.965501517353668,
                    0.8222681603043882,
                    0.907212186549752,
                    0.8611990749004483,
                    0.9075083701591861,
                    0.9452697087352507,
                    0.8792221638449876,
                    0.9261547230875113,
                    0.8628843081873457,
                    0.7825774684929468,
                    0.826587869384389,
                    0.7399697967202743,
                    0.7855475169419611,
                    0.9111614040328055,
                    0.8871057104665023,
                    0.8824403169032524,
                    0.8618250237448342,
                    0.9787037901590712,
                    0.8066230600465234,
                    0.82769109211187,
                    0.9246432066709112,
                    0.8840853142719706,
                    0.786484352300887,
                    0.9106863314452291,
                    0.9342800776988389,
                    0.8573335079738363,
                    0.7780871062391552,
                    0.7913314665963198,
                    0.8574377397709955,
                    0.9078366114351649,
                    0.7529739278117176,
                    0.8630106564789936,
                    0.9051526759332171,
                    0.7715467442267975,
                    0.894146587858075,
                    0.8095881901541373,
                    0.7733578316514722,
                    0.7600408374028372,
                    0.7819972017869313,
                    0.9003461714649055,
                    0.7428048011790054,
                    0.8645936498599726,
                    0.8158769881622202,
                    0.8338827592994027,
                    0.8272653967731632,
                    0.9017517376316297,
                    0.8480852025321463,
                    0.7970818331146111,
                    0.84837067058732,
                    0.9272909953469497,
                    0.9511439765702141,
                    0.8796630924442385,
                    0.8297595339070903,
                    0.8132311695727563,
                    0.846096509352579,
                    0.8787645385610889,
                    0.8591367322994465,
                    0.8452813265723063,
                    0.708120854745005,
                    0.8769677220320785,
                    0.957621649201,
                    0.7463356296247886,
                    0.8618039385828121,
                    0.9560112462834625,
                    0.8478374736767776,
                    0.769289015775232,
                    0.8456866935685449,
                    0.9014601942765245,
                    0.8816990623836058,
                    0.8836365012367311,
                    0.8078009762348856,
                    0.898471670333294,
                    0.9064470715463968,
                    0.8762712610709371,
                    0.9178852315161201,
                    0.7896235958446132,
                    0.8939345739482307,
                    0.9534018415944343,
                    0.8358882135213556,
                    0.9488657107840998,
                    0.9046799883600772,
                    0.7583576517359092,
                    0.9080459939022663,
                    0.7709722685822517,
                    0.9635512477648485,
                    0.9792712672362888,
                    0.8526700765974843,
                    0.8278133097990042,
                    0.9735858611728696,
                    0.721230194834802,
                    0.8257425311085005,
                    0.9243205490037274,
                    0.9183796453898277,
                    0.9029146937248601,
                    0.9410246048181872,
                    0.9609604036664618,
                    0.7467407974920351,
                    0.8831901217813377,
                    0.8173287200784538,
                    0.8067347032404598,
                    0.7921957436175584,
                    0.9110994793014074,
                    0.8678737504217436,
                    0.9117743220816726,
                    0.781256498099708,
                    0.8553931170417658,
                    0.8798565764815073,
                    0.8485358179238083,
                    0.7748765495278522,
                    0.9432062986428599,
                    0.8328320716129907,
                    0.7983629771463491,
                    0.9345589964322458,
                    0.780034700168418,
                    0.9894680324844076,
                    0.8239308905190431,
                    0.8236003498823307,
                    0.8346101074957768,
                    0.8273793488443836,
                    0.7872103673165679,
                    0.9502897884724039,
                    0.83306630344504,
                    0.9346568240975981,
                    0.8082083566882774,
                    0.8920672687911747,
                    0.8566523137465843,
                    0.7636170858064102,
                    0.8271498146927989,
                    0.8450776675259235,
                    0.9045266249905214,
                    0.8578964053464326,
                    0.8673866119197355,
                    0.8804224181917263,
                    0.8199459545727489,
                    0.932410033971128,
                    0.9096821262750205,
                    0.8658255622247675,
                    0.9386720385226357,
                    0.8517830114204678,
                    0.889433736353776,
                    0.978847593529176,
                    0.8369738178686744,
                    0.8438616304896431,
                    0.9457050132083015,
                    0.8699723446274389,
                    0.7795221418941651,
                    0.9136284834408402,
                    0.839461038218975,
                    0.9453279809479284,
                    0.7899532071809374,
                    0.9078373589637019,
                    0.8434980548175782,
                    0.8112068688921613,
                    0.9466506411385276,
                    0.931413666561485,
                    0.7932453730948084,
                    0.8205411414836395,
                    0.9243834389669592,
                    0.719616209472987,
                    0.7552985082978257,
                    0.9593440978701407,
                    0.917557937781671,
                    0.8643861907339583,
                    0.8315201130646898,
                    0.7608819746989568,
                    0.9704324558544556,
                    0.8037085307118571,
                    0.7785270629127378,
                    0.8044961889638995,
                    0.8313307525316546,
                    0.8064106357955508,
                    0.9291149169587132,
                    0.8412940941318643,
                    0.6917091086045903,
                    0.895204432897799,
                    0.8182250729145697,
                    0.8645847241904496,
                    0.8532020284403578,
                    0.8143634597102418,
                    0.8825747762544934,
                    0.7764652529619696,
                    0.850099367461302,
                    0.8616919096196407,
                    0.9257293995599788,
                    0.935772204341753,
                    0.7742657205171264,
                    0.7898710076766889,
                    0.8590438495382458,
                    0.9317809680608741,
                    0.9087109944408146,
                    0.949297999227784,
                    0.8813316524294974,
                    0.7372081408133433,
                    0.8838176418404169
                  ],
                  "xaxis": "x2",
                  "yaxis": "y2"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "1",
                  "marker": {
                    "color": "#636efa",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "1",
                  "offsetgroup": "1",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "histogram",
                  "x": [
                    0.9424796847531719,
                    0.9078956620721231,
                    0.8334324872928349,
                    0.9352180099421901,
                    0.9055462980371385,
                    0.8981939710469806,
                    0.8310153256803275,
                    0.8504676050313356,
                    0.8456281884302819,
                    0.8845204616348977,
                    0.9575409743129409,
                    0.8867362111499236,
                    0.8268148049156373,
                    0.9197424487445726,
                    0.7868932880014918,
                    0.7584994087629051,
                    0.9184151106611779,
                    0.8634069821364065,
                    0.8347803687511831,
                    0.8293627315810226,
                    0.9290633380383377,
                    0.8385821691321573,
                    0.9389267221382397,
                    0.890818442649049,
                    0.86634760562554,
                    0.8406483291061461,
                    0.8084243397427517,
                    0.8909500804242533,
                    0.9262896019507018,
                    0.8955541231110083,
                    0.8055268512037249,
                    0.758607678624549,
                    0.9609058491722305,
                    0.91495905751538,
                    0.8670137150812314,
                    0.8813831602682358,
                    0.8602255150075953,
                    0.9239960998195261,
                    0.91732217769794,
                    0.8037285395908439,
                    0.9196344979660896,
                    0.8179495018382927,
                    0.9015423003762686,
                    0.9054394610623084,
                    0.9309412941515826,
                    0.9421722892808463,
                    0.7632823176731237,
                    0.8622055676377434,
                    0.9855273108710355,
                    0.9144155566597573,
                    0.916057392559463,
                    0.8027504546689012,
                    0.7131090055200247,
                    0.86174194818737,
                    0.982873171115144,
                    0.8100227539053508,
                    0.8923878603823862,
                    0.8096643432711604,
                    0.8707613707684018,
                    0.8786740148818871,
                    0.8274639911687568,
                    0.8927098767487393,
                    0.9565597075186062,
                    0.9060728095743347,
                    0.738307517030644,
                    0.9645943657917937,
                    0.8755564011033787,
                    0.879644342330288,
                    0.8679709669180851,
                    0.9304235148160597,
                    0.8902804960376421,
                    0.8748369557157689,
                    0.9999999999999999,
                    0.7979398167195422,
                    0.8182553472910762,
                    0.7782108671759803,
                    0.8427610550014142,
                    0.8696408842646327,
                    0.8747903026537825,
                    0.914973368517656,
                    0.9651568968718233,
                    0.9775547988384379,
                    0.8964005885715726,
                    0.8689760349348122,
                    0.8501707274497268,
                    0.9069421089316828,
                    0.7682621587597694,
                    0.9658683147667629,
                    0.8946443487011166,
                    0.7855154267442119,
                    0.8963791544804274,
                    0.8062904921192978,
                    0.8165205978948239,
                    0.8392522254625653,
                    0.9456080863923884,
                    0.7904904126133966,
                    0.833126792097794,
                    0.7852156051245299,
                    0.7859162398886373,
                    0.9097674992318959,
                    0.8868692153350769,
                    0.9391826646753169,
                    0.9428151202937937,
                    0.7923603877885725,
                    0.9018727189911658,
                    0.9723161942344292,
                    0.7820369113228325,
                    0.9667234201162873,
                    0.978769627389609,
                    0.9155729781931277,
                    0.8273013970664075,
                    0.9603319621485501,
                    0.9298975009081959,
                    0.8775117467919693,
                    0.8614509568162568,
                    0.9144155657624043,
                    0.7783710792186642,
                    0.9701880190033496,
                    0.7858944693777298,
                    0.9278353490304795,
                    0.9472367444800102,
                    0.7834809788888359,
                    0.7997358978342995,
                    0.8459052935830644,
                    0.8612076995259889,
                    0.8470901722260982,
                    0.824037271004761,
                    0.865608650068826,
                    0.8023193246081538,
                    0.7836788857151791,
                    0.880404135119559,
                    0.8491559249509729,
                    0.788345270395367,
                    0.9461393747813323,
                    0.835123385080455,
                    0.8158174048388718,
                    0.8604581300972295,
                    0.9623616555004862,
                    0.8564688397784819,
                    0.8576867681204893,
                    0.8973905356978807,
                    0.8634447095761868,
                    0.8149528594606367,
                    0.873171253801101,
                    0.8653347676818378,
                    0.929525558735665,
                    0.8358267202818717,
                    0.971888682386554,
                    0.8500189240129448,
                    0.6201715858790215,
                    0.8982737437906866,
                    0.8919523978597029,
                    0.7327218620224804,
                    0.8329671225228632,
                    0.9265589851585685,
                    0.8976605724257473,
                    0.8865148832512937,
                    0.7893917264917192,
                    0.7303107673595587,
                    0.8428958487387793,
                    0.8712646524042454,
                    0.9726111208329614,
                    0.9368020234351375,
                    0.9270010843622818,
                    0.8900608737128451,
                    0.7975173141451626,
                    0.9403308743666237,
                    0.8484005148509558,
                    0.9285585494125882,
                    0.8461714640960527,
                    0.9301612552439464,
                    0.9840391344904358,
                    0.8305503032909712,
                    0.8985536902480058,
                    0.9476344055766088,
                    0.9342892661789283,
                    0.8849523247638441,
                    0.7736620851030366,
                    0.8083290901088768,
                    0.9510007696957726,
                    0.8677438102591386,
                    0.8324233959261911,
                    0.7379868650067231,
                    0.9049462205283083,
                    0.9044068964151009,
                    0.7810399099399672,
                    0.9040280419401343,
                    0.7720832557964016,
                    0.7168259249496903,
                    0.8657076231674912,
                    0.9689982290529879,
                    0.9330371348632155,
                    0.7014093149115691,
                    0.9056081832768474,
                    0.8483474394912361,
                    0.8729108893663646,
                    0.8494252835407102,
                    0.8702668029663239,
                    0.8703072652243842,
                    0.9279473628052111,
                    0.8615930026010632,
                    0.7590822846968434,
                    0.8435232136599701,
                    0.8264379743713927,
                    0.8793126202650794,
                    0.847452301256391,
                    0.7546334370590053,
                    0.8870818568791698,
                    0.8349553719854912,
                    0.9232007589383142,
                    0.7924421895458662,
                    0.8556103146657943,
                    0.8397958720336148,
                    0.9358165878534705,
                    0.904577353436614,
                    0.9022537114781464,
                    0.775603917181226,
                    0.946091618927627,
                    0.8264119461162666,
                    0.8261258105029201,
                    0.8605336598142989,
                    0.7518422489499549,
                    0.849587557879856,
                    0.9922578397415717,
                    0.7499254104864422,
                    0.8845204616348977,
                    0.8361936554693772,
                    0.9172228808488442,
                    0.8068135566092208,
                    0.795739929008372,
                    0.8632611464779958,
                    0.7612462576141025,
                    0.9589125421073597,
                    0.9555759038945358,
                    0.8822980105025566,
                    0.9663740139243834,
                    0.9071760951442449,
                    0.9335338894977118,
                    0.8042262160785201,
                    0.9399607295068667,
                    0.8318513711395181,
                    0.8697471272873054,
                    0.9103391819002411,
                    0.8272582065159091,
                    0.7868989539137853,
                    0.7416168920325092,
                    0.8828593510646834,
                    0.9141342991325323,
                    0.7259887492833588,
                    0.9478299721997572,
                    0.843766518385904,
                    0.919830425538435,
                    0.9069062941852448,
                    0.9036466185261808,
                    0.9817542893707696,
                    0.8833292621745382,
                    0.832556614533968,
                    0.8135910443631924,
                    0.9628932969024508,
                    0.9450804655496136,
                    0.9226384091529121,
                    0.8401818103049188,
                    0.723691406760321,
                    0.6828741135249211,
                    0.834410523069395,
                    0.9959256404542386,
                    0.952870396433041,
                    0.969514692554192,
                    0.9220387806044666,
                    0.9511950116111251,
                    0.87442203104197,
                    0.8399026046246612,
                    0.9029483760650204,
                    0.9097073428917352,
                    0.8651925582004045,
                    0.9178332691819033,
                    0.7556713752294848,
                    0.8601740894614596,
                    0.8250804250840363,
                    0.799473306929639,
                    0.8911389639861809,
                    0.915913776235107,
                    0.7867422041389165,
                    0.8035116695233039,
                    0.7702882636946234,
                    0.9060460430333088,
                    0.7214029229730072,
                    0.8607904806397634,
                    0.8228468643082103,
                    0.8900020169140401,
                    0.9343567736626528,
                    0.9305049279291139,
                    0.9664193138643489,
                    0.9008537853184969,
                    0.7625840742620333,
                    0.815302054727336,
                    0.9215061720798934,
                    0.7192673801865671,
                    0.8949994067748966,
                    0.9367566547265034,
                    0.7602684166275758,
                    0.8184439767612992,
                    0.8361983856596491,
                    0.7761725471827079,
                    0.7724780968772909,
                    0.9249211346782868,
                    0.8718843131924867,
                    0.8522890335712519,
                    0.9015475867709434,
                    0.8720699810318118,
                    0.8937599387455695,
                    0.8721713573852221,
                    0.8100783166142076,
                    1.0000000000000002,
                    0.8213222537973748,
                    0.8361185401136565,
                    0.8371907459006128,
                    0.9065697385076582,
                    0.752240671472798,
                    0.8283078905766531,
                    0.8499886819287953,
                    0.9097932369637356,
                    0.9529813104528191,
                    0.8449289750214674,
                    1,
                    0.8302949362084788,
                    0.7741532046500113,
                    0.8743828037305432,
                    0.8201855611163867,
                    0.8194689758101458,
                    0.7925076796225758,
                    0.8748126117575765,
                    0.8299510305557958,
                    0.9619426561868236,
                    0.8627070029199212
                  ],
                  "xaxis": "x",
                  "yaxis": "y"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=-1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "-1",
                  "marker": {
                    "color": "#EF553B",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "-1",
                  "offsetgroup": "-1",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "histogram",
                  "x": [
                    0.7299945446332757,
                    0.761829365033925,
                    0.676235270353631,
                    0.7023016593603112,
                    0.7350156032869306,
                    0.7735236691667362,
                    0.7187241641280292,
                    0.8063818744486255,
                    0.7115273138749911,
                    0.7402039545462252,
                    0.7372226375565185,
                    0.719606336264161,
                    0.8344052997670786,
                    0.6881482918877712,
                    0.650513335277676,
                    0.7182967436895931,
                    0.7659970793010594,
                    0.6361476177437694,
                    0.7957002158983555,
                    0.7395402073375513,
                    0.7614384854511241,
                    0.6835300790002988,
                    0.6209194558826169,
                    0.7907726220786139,
                    0.7215750502680501,
                    0.7271947538454276,
                    0.6962333614625641,
                    0.7517476038206041,
                    0.7135529871009506,
                    0.7522919529124952,
                    0.7120639628754573,
                    0.7623014780353815,
                    0.7939574492876662,
                    0.6998873138766412,
                    0.7700594720774577,
                    0.7766161530342343,
                    0.7285080945118152,
                    0.7562511166739386,
                    0.8086622737220109,
                    0.7565297004511734,
                    0.7315242427462245,
                    0.791225232078265,
                    0.7281092467134649,
                    0.7675685917886341,
                    0.7122436997016454,
                    0.7600255476588682,
                    0.769465992253443,
                    0.7552507233110458,
                    0.7373719614604455,
                    0.7681449827665134,
                    0.7760194768221379,
                    0.8035677492769473,
                    0.7771752906080505,
                    0.6514739683312004,
                    0.744787832649546,
                    0.7700232252518491,
                    0.6901772464238046,
                    0.721128796446804,
                    0.686814078974691,
                    0.796786900996621,
                    0.8242176320872988,
                    0.7806742901384496,
                    0.7697696656361902,
                    0.7868668497422822,
                    0.7304548331410279,
                    0.7296767182508448,
                    0.699401040331297,
                    0.8332660282838579,
                    0.6513224421864793,
                    0.6927364371405096,
                    0.7491793002300279,
                    0.7909034218879171,
                    0.7754176150761527,
                    0.8004827006684735,
                    0.6659923526948868,
                    0.8129618884980553,
                    0.7496476113488948,
                    0.6519665061955423,
                    0.7319506042291191,
                    0.7367099004846358,
                    0.8590817275589286,
                    0.6684490623528697,
                    0.7469140136676841,
                    0.7269016830127544,
                    0.6834261236240542,
                    0.6390380393123325,
                    0.7129209978685723,
                    0.6879274953497815,
                    0.6720427402498903,
                    0.7343923261152341,
                    0.5977941468082364,
                    0.7574521074337901,
                    0.7302129263006347,
                    0.7380209108764002,
                    0.7280177437983031,
                    0.6870880300968067,
                    0.6928024686173956,
                    0.6403900073451696,
                    0.7209247906698092,
                    0.666799070142464,
                    0.7233569397428488,
                    0.7555267048881131,
                    0.7275931437975757,
                    0.7562785127332186,
                    0.736649021802634,
                    0.762569980781232,
                    0.7741923768467133,
                    0.6669773662198453,
                    0.7499779113599104,
                    0.7783371410262362,
                    0.7798574909618832,
                    0.7068277719647446,
                    0.7718066533969561,
                    0.7078874155127759,
                    0.7238814624322412,
                    0.8729026036201937,
                    0.7106759057210158,
                    0.7585767440008555,
                    0.6923565683937588,
                    0.6693898561996529,
                    0.7219944003835542,
                    0.7188064794656569,
                    0.7491451951131076,
                    0.6750197249394477,
                    0.7009868838756784,
                    0.6519589230722103,
                    0.775109860666869,
                    0.6682014813660948,
                    0.6618358724923147,
                    0.6218362070243146,
                    0.7518134154555007,
                    0.7571307427362168,
                    0.7546823360234822,
                    0.7416435744760188,
                    0.7676464608286054,
                    0.5799855321635428,
                    0.796191792361925,
                    0.6845373357552841,
                    0.7667984177770908,
                    0.7393609881148844,
                    0.7544580057165962,
                    0.7397495323890664,
                    0.7658298386950907,
                    0.6611125314078552,
                    0.7977728876184589,
                    0.831090991824215,
                    0.6982347705041605,
                    0.6312226759947304,
                    0.6482907496231076,
                    0.7658362595299695,
                    0.7518400125995974,
                    0.8025480920087656,
                    0.7461501494015976,
                    0.7239149641844659,
                    0.7090055516721123,
                    0.6846622430587006,
                    0.7112212118684095,
                    0.6794460508450644,
                    0.8344462829585355,
                    0.7638782230495382,
                    0.6558255308015893,
                    0.6799836520005059,
                    0.7148088830660722,
                    0.7658007439571484,
                    0.6581857665503666,
                    0.648734015773256,
                    0.8156725527309459,
                    0.6929202590284855,
                    0.7490919589112273,
                    0.7090723359022927,
                    0.7105572886194162,
                    0.7461374422467866,
                    0.7084742440808511,
                    0.6889378818898818,
                    0.7551565238112727,
                    0.7198789279593328,
                    0.7270482774940944,
                    0.6971190427893721,
                    0.7391610904601401,
                    0.6344734499604212,
                    0.6719507302378157,
                    0.6861159059531602,
                    0.6516118314317232,
                    0.7199095105818035,
                    0.6817881072485698,
                    0.7207373940156253,
                    0.7745467156569463,
                    0.6258059783246434,
                    0.7481056513643776,
                    0.7183327989261993,
                    0.7624705969064652,
                    0.703717229048042,
                    0.7487365873620485,
                    0.7495555007485976,
                    0.6409624588579408,
                    0.7176245775200687,
                    0.7537100520717849,
                    0.7569868075002099,
                    0.7392135510982174,
                    0.7188471960732984,
                    0.697413550280765,
                    0.7138614496887067,
                    0.7057641350396069,
                    0.7675079665142936,
                    0.7310427541091186,
                    0.7808018818735418,
                    0.7567255895826445,
                    0.7035962262766099,
                    0.7040750813384501,
                    0.8183159919038065,
                    0.7953911933398697,
                    0.7464891038038547,
                    0.6751591827598264,
                    0.7849943676377955,
                    0.7155963442284841,
                    0.7428993249606122,
                    0.7131100645054201,
                    0.7227595311429733,
                    0.6519548345531954,
                    0.7201522118536183,
                    0.86540799716664,
                    0.8128819241371503,
                    0.7278912446692862,
                    0.7305867175950502,
                    0.7171875192153516,
                    0.6755179003509543,
                    0.7256221359402913,
                    0.7003814947129137,
                    0.7486334697158199,
                    0.7232489166529666,
                    0.7347697330652992,
                    0.6493837702986368,
                    0.6454310256268904,
                    0.7085305859966166,
                    0.7709963397003181,
                    0.7628122486461532,
                    0.7260869667056613,
                    0.7656074896314918,
                    0.7309944223135776,
                    0.7575162043117213,
                    0.7425954755181217,
                    0.7978452334414571,
                    0.7414129597626153,
                    0.7369987033441427,
                    0.7249664966482501,
                    0.663939118162477,
                    0.7490232329485363,
                    0.7532303509685747,
                    0.6505502824396713,
                    0.6820403873171862,
                    0.7458589415356044,
                    0.65106761846338,
                    0.8190794585362886,
                    0.6404595320063431,
                    0.7620011212588133,
                    0.6793344580417779,
                    0.7470455239529016,
                    0.6254743025101126,
                    0.714021296346165,
                    0.7624376857959331,
                    0.6124088443110871,
                    0.7190909082546953,
                    0.7667977482791689,
                    0.7390282391537781,
                    0.731411776312689,
                    0.6910654011387792,
                    0.7669731150300952,
                    0.7473599833172766,
                    0.7757742306046117,
                    0.7524096654164605,
                    0.668078105815945,
                    0.6797298899209671,
                    0.734572374544983,
                    0.7851097143298676,
                    0.7342031538272321,
                    0.7372538892798539,
                    0.7335209046034747,
                    0.6838366965805461,
                    0.6892908218001774,
                    0.7368799079039347,
                    0.667817778038092,
                    0.7436021083467266,
                    0.643687287978847,
                    0.6459012534104801,
                    0.717961524900317,
                    0.651811280493696,
                    0.774663511332172,
                    0.6481450536649574,
                    0.7135017154296592,
                    0.718243043672562,
                    0.6840974559559992,
                    0.6237039667424505,
                    0.7511301324631847,
                    0.6731554748777204,
                    0.7311433676647258,
                    0.7407740108803432,
                    0.7219713524118947,
                    0.6945284597258313,
                    0.7403964086997485,
                    0.7281416758951161,
                    0.767616849591671,
                    0.7623013461443361,
                    0.794278871148869,
                    0.7094595086377886,
                    0.6363511838820268,
                    0.6955392554405656,
                    0.7448982185809608,
                    0.7328610107493276,
                    0.7208624134565648,
                    0.6762963031433926,
                    0.7755406771546928,
                    0.7045876515082797,
                    0.6244682388287953,
                    0.6742169930443296,
                    0.8182351707309709,
                    0.7329206562104693,
                    0.7750198586773644,
                    0.7686712995024062,
                    0.7382706738683358,
                    0.6670365140953741,
                    0.7122098843228497,
                    0.720363344417827,
                    0.7260325476617868,
                    0.7455849615803621,
                    0.7135971488970759,
                    0.7597698332957011,
                    0.7261113201174606,
                    0.7802411718313292,
                    0.6937507195359754,
                    0.7842882648198792,
                    0.6900501446041246,
                    0.760860218311663,
                    0.7134088358525988,
                    0.7053629634417926
                  ],
                  "xaxis": "x2",
                  "yaxis": "y2"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=-1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "-1",
                  "marker": {
                    "color": "#EF553B",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "-1",
                  "offsetgroup": "-1",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "histogram",
                  "x": [
                    0.698216609902851,
                    0.7175312484264973,
                    0.7186004945024,
                    0.8000305631283213,
                    0.6982596730429885,
                    0.7632305498724974,
                    0.7138465594512551,
                    0.6788567152998355,
                    0.7373203069430523,
                    0.6873036529619025,
                    0.6503465541274354,
                    0.7365009792588864,
                    0.741093678604772,
                    0.735107851927508,
                    0.7789882788330387,
                    0.6997359150722298,
                    0.7996799028712657,
                    0.7467428244049107,
                    0.6687862526227949,
                    0.743064589979991,
                    0.8567601871608912,
                    0.7518934598338161,
                    0.7026922402796603,
                    0.7211365080382613,
                    0.706354045904417,
                    0.779427003097291,
                    0.7362962393365906,
                    0.7291751132805836,
                    0.7122378769301289,
                    0.7140477948952889,
                    0.7173405960056081,
                    0.7634875929311733,
                    0.7977581390590794,
                    0.7301463738402032,
                    0.7615332057968679,
                    0.682227985030492,
                    0.7334635989895589,
                    0.7386028577038499,
                    0.659084291835728,
                    0.7899820114980755,
                    0.7247172517704278,
                    0.7438155210487466,
                    0.7005346860296618,
                    0.6648553956533266,
                    0.7701566644966253,
                    0.7514961574415904,
                    0.7587991656983686,
                    0.7001882273133521,
                    0.6910707516646375,
                    0.7394355361240693,
                    0.7276824899179835,
                    0.6759744362016779,
                    0.8302185470787163,
                    0.6928641094502374,
                    0.7120538723839331,
                    0.7224960785372221,
                    0.7198045816069757,
                    0.6813558762031737,
                    0.7165801628045559,
                    0.7120832723046919,
                    0.8420619371167495,
                    0.946387336435492,
                    0.7554566849916029,
                    0.6539169025880401,
                    0.7809846972343978,
                    0.7724403150471626,
                    0.7005276086814838,
                    0.6393757830582598,
                    0.8206678516495857,
                    0.7220887623700465,
                    0.6457268309147112,
                    0.7355556783165726,
                    0.7154485704709247,
                    0.736485552747626,
                    0.6279868336962336,
                    0.6826307018159569,
                    0.6893086023402188,
                    0.6662701662224271,
                    0.7867533724923507,
                    0.7767747518359883,
                    0.8265509786609969,
                    0.7191298707058883,
                    0.7022356632617615,
                    0.7327905404619434,
                    0.744068997548466,
                    0.7610196098150734,
                    0.7115456073990181,
                    0.7432332956176703,
                    0.7893785433034154,
                    0.7290401089931655,
                    0.6577253300110991,
                    0.7003577024217508,
                    0.6656632326716906,
                    0.777774068734494,
                    0.7332058736487923,
                    0.6832255453323396,
                    0.7959026283113317,
                    0.7447573588360193,
                    0.6641622569109379,
                    0.6536973982676278,
                    0.6665423422110455,
                    0.6839542281076658,
                    0.7423443203765382,
                    0.7548386221527327,
                    0.6027110439271006,
                    0.6910860362919979,
                    0.6562817652314046,
                    0.716178983425107,
                    0.724403115068019,
                    0.7259909803954812,
                    0.7571530348849562,
                    0.733442357026098,
                    0.7422263713152575,
                    0.7194490356714829,
                    0.8324274302968815,
                    0.7326854438693811,
                    0.6965534375434462,
                    0.6112368704742529,
                    0.7304384905441154,
                    0.7720040972080942,
                    0.6935032899149339,
                    0.6744196671055273,
                    0.7147585872166342,
                    0.7752945283086001,
                    0.7247252468203336,
                    0.7487128692844189,
                    0.7073695766484023,
                    0.7770002807438351,
                    0.6765039364068596,
                    0.6337555851438214,
                    0.7894145395726685,
                    0.808127657824718,
                    0.6806212157836168,
                    0.6631556694778886,
                    0.7026864982467919,
                    0.6616453779258454,
                    0.7884566439105933,
                    0.7066759939380831,
                    0.6337579382878964,
                    0.6487574469000645,
                    0.772145520185077,
                    0.7232462166936153,
                    0.7901145921506159,
                    0.6972515066616096,
                    0.7122390390469722,
                    0.7192738985202007,
                    0.6592857746430886,
                    0.7609468966795828,
                    0.6603956483288737,
                    0.7457664593808938,
                    0.6826889303879785,
                    0.7934081087742163,
                    0.6763383894378445,
                    0.6839731303810913,
                    0.6687692061586001,
                    0.6991702352595418,
                    0.7797875942740597,
                    0.7663428993137384,
                    0.7674609664886143,
                    0.7252797254797089,
                    0.6976652869077468,
                    0.7158816756428659,
                    0.718241456712601,
                    0.7632800029357677,
                    0.7220393239650442,
                    0.7734594213360495,
                    0.6831354341128338,
                    0.8050740585092945,
                    0.8304016729411579,
                    0.6942297180429666,
                    0.7777210539532041,
                    0.7971492766742209,
                    0.7551377672686813,
                    0.756111346085797,
                    0.7377139073001521,
                    0.7292914327831885,
                    0.6619852966467218,
                    0.6633387313486273,
                    0.7753943622795527,
                    0.6931366760284077,
                    0.6879938635549627,
                    0.7934277451864097,
                    0.7095961561996172,
                    0.6721647904665006,
                    0.6639348167332335,
                    0.7190874751718003,
                    0.6487682662731885,
                    0.7712237228183192,
                    0.7195541308325332,
                    0.7624245070018526,
                    0.7066568592895756,
                    0.6955819267956074,
                    0.762644689565747,
                    0.696550099953965,
                    0.72160309057877,
                    0.6589853583691466,
                    0.7781076723886485,
                    0.7844353288099747,
                    0.6499942545071061,
                    0.7586643115109818,
                    0.7851245713510661,
                    0.6825431110733673,
                    0.7920473550917971,
                    0.7505505200683399,
                    0.7112992195413225,
                    0.6872424297540068,
                    0.6629403755138552,
                    0.7754417757832819,
                    0.7445419843378304,
                    0.7064660255551196,
                    0.7102764764345906,
                    0.7166584523700176,
                    0.745706231193076,
                    0.7628022956052315,
                    0.6960882904963708,
                    0.6837468719098787,
                    0.7520487263229376,
                    0.7604129787986132,
                    0.7522054011542562,
                    0.6898973175025894,
                    0.712053903439596,
                    0.7339122116041967,
                    0.6789458999380491,
                    0.763449576758855,
                    0.7406926320689465,
                    0.6976029213628422,
                    0.7734670110437211,
                    0.7670042286239444,
                    0.7194794729796166,
                    0.8039337600845294,
                    0.822887238017454,
                    0.7546355748553022,
                    0.712175872919759,
                    0.6283408438534046,
                    0.7277722488820856,
                    0.7848289730316366,
                    0.6470175548703326,
                    0.7175970646556328,
                    0.6982508276209234,
                    0.6931585425755658,
                    0.7105706262503181,
                    0.6541269887120206,
                    0.8404400660965669,
                    0.7278163563567496,
                    0.8022018950050529,
                    0.7449136144274442,
                    0.7254549036563473,
                    0.7708530061010334,
                    0.7226568414320923,
                    0.7174427406313559,
                    0.7242867487973796,
                    0.7465548093208114,
                    0.6481248100208876,
                    0.715420475795481,
                    0.7818995378198694,
                    0.710140386310007,
                    0.779070581516581,
                    0.7022765286826387,
                    0.7920014977874291,
                    0.7028249663680173,
                    0.7464011128583546,
                    0.6874796697586977,
                    0.7834259382246206,
                    0.7487992683674399,
                    0.6256280566008573,
                    0.7248416704075088,
                    0.6787298488859722,
                    0.7604099689852636,
                    0.7563309422531382,
                    0.6536692254847523,
                    0.7277402265583088,
                    0.7961681595257355,
                    0.7183023940359037,
                    0.8578324194628058,
                    0.6890352710148969,
                    0.711616174722821,
                    0.6560825239577808,
                    0.7235723022023411,
                    0.6649236563739442,
                    0.6800589793852263,
                    0.7785177771732936,
                    0.8277144895457349,
                    0.7047472917613488,
                    0.6981993581133783,
                    0.69214194925211,
                    0.7175225477364914,
                    0.6821700037384272,
                    0.6934882153010823,
                    0.6671459724713757,
                    0.7577056235720239,
                    0.7452347481085622,
                    0.7540647847248615,
                    0.7623045528688165,
                    0.8419961516314336,
                    0.7607631404114222,
                    0.7104592749279437,
                    0.7907219223857475,
                    0.6626370861321504,
                    0.7293323972767943,
                    0.6747790790615374,
                    0.7205393564241827,
                    0.7182141925188444,
                    0.6698620455142402,
                    0.7774615433926413,
                    0.68441903533162,
                    0.7195176784475413,
                    0.7765542578440102,
                    0.7653003235671018,
                    0.6588957811563716,
                    0.7049538466814345,
                    0.6767019827253833,
                    0.6852115350974048,
                    0.7159808946533304,
                    0.6275008181698795,
                    0.6641598464064111,
                    0.7653064009797307,
                    0.7846062245731015,
                    0.7131195190890488,
                    0.7388407888274415,
                    0.7078575690975646,
                    0.7922969773693673,
                    0.6399205949452071,
                    0.7522331808600956,
                    0.756127025845852,
                    0.7527950868321375,
                    0.7791392496558245,
                    0.7388745760013306,
                    0.6739605493576779,
                    0.6432673241279167,
                    0.7124181751534668,
                    0.669456709871883,
                    0.7067049471522552,
                    0.6685698115209102,
                    0.7430777123010961,
                    0.7510627360284545
                  ],
                  "xaxis": "x",
                  "yaxis": "y"
                }
              ],
              "layout": {
                "annotations": [
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "dataset=test",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.2425,
                    "yanchor": "middle",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "dataset=train",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.7575000000000001,
                    "yanchor": "middle",
                    "yref": "paper"
                  }
                ],
                "barmode": "overlay",
                "legend": {
                  "title": {
                    "text": "label"
                  },
                  "tracegroupgap": 0
                },
                "margin": {
                  "t": 60
                },
                "template": {
                  "data": {
                    "bar": [
                      {
                        "error_x": {
                          "color": "#2a3f5f"
                        },
                        "error_y": {
                          "color": "#2a3f5f"
                        },
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "bar"
                      }
                    ],
                    "barpolar": [
                      {
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "barpolar"
                      }
                    ],
                    "carpet": [
                      {
                        "aaxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "baxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "type": "carpet"
                      }
                    ],
                    "choropleth": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "choropleth"
                      }
                    ],
                    "contour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "contour"
                      }
                    ],
                    "contourcarpet": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "contourcarpet"
                      }
                    ],
                    "heatmap": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmap"
                      }
                    ],
                    "heatmapgl": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmapgl"
                      }
                    ],
                    "histogram": [
                      {
                        "marker": {
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "histogram"
                      }
                    ],
                    "histogram2d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2d"
                      }
                    ],
                    "histogram2dcontour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2dcontour"
                      }
                    ],
                    "mesh3d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "mesh3d"
                      }
                    ],
                    "parcoords": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "parcoords"
                      }
                    ],
                    "pie": [
                      {
                        "automargin": true,
                        "type": "pie"
                      }
                    ],
                    "scatter": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter"
                      }
                    ],
                    "scatter3d": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter3d"
                      }
                    ],
                    "scattercarpet": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattercarpet"
                      }
                    ],
                    "scattergeo": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergeo"
                      }
                    ],
                    "scattergl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergl"
                      }
                    ],
                    "scattermapbox": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattermapbox"
                      }
                    ],
                    "scatterpolar": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolar"
                      }
                    ],
                    "scatterpolargl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolargl"
                      }
                    ],
                    "scatterternary": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterternary"
                      }
                    ],
                    "surface": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "surface"
                      }
                    ],
                    "table": [
                      {
                        "cells": {
                          "fill": {
                            "color": "#EBF0F8"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "header": {
                          "fill": {
                            "color": "#C8D4E3"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "type": "table"
                      }
                    ]
                  },
                  "layout": {
                    "annotationdefaults": {
                      "arrowcolor": "#2a3f5f",
                      "arrowhead": 0,
                      "arrowwidth": 1
                    },
                    "autotypenumbers": "strict",
                    "coloraxis": {
                      "colorbar": {
                        "outlinewidth": 0,
                        "ticks": ""
                      }
                    },
                    "colorscale": {
                      "diverging": [
                        [
                          0,
                          "#8e0152"
                        ],
                        [
                          0.1,
                          "#c51b7d"
                        ],
                        [
                          0.2,
                          "#de77ae"
                        ],
                        [
                          0.3,
                          "#f1b6da"
                        ],
                        [
                          0.4,
                          "#fde0ef"
                        ],
                        [
                          0.5,
                          "#f7f7f7"
                        ],
                        [
                          0.6,
                          "#e6f5d0"
                        ],
                        [
                          0.7,
                          "#b8e186"
                        ],
                        [
                          0.8,
                          "#7fbc41"
                        ],
                        [
                          0.9,
                          "#4d9221"
                        ],
                        [
                          1,
                          "#276419"
                        ]
                      ],
                      "sequential": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ],
                      "sequentialminus": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ]
                    },
                    "colorway": [
                      "#636efa",
                      "#EF553B",
                      "#00cc96",
                      "#ab63fa",
                      "#FFA15A",
                      "#19d3f3",
                      "#FF6692",
                      "#B6E880",
                      "#FF97FF",
                      "#FECB52"
                    ],
                    "font": {
                      "color": "#2a3f5f"
                    },
                    "geo": {
                      "bgcolor": "white",
                      "lakecolor": "white",
                      "landcolor": "#E5ECF6",
                      "showlakes": true,
                      "showland": true,
                      "subunitcolor": "white"
                    },
                    "hoverlabel": {
                      "align": "left"
                    },
                    "hovermode": "closest",
                    "mapbox": {
                      "style": "light"
                    },
                    "paper_bgcolor": "white",
                    "plot_bgcolor": "#E5ECF6",
                    "polar": {
                      "angularaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "radialaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "scene": {
                      "xaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "yaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "zaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      }
                    },
                    "shapedefaults": {
                      "line": {
                        "color": "#2a3f5f"
                      }
                    },
                    "ternary": {
                      "aaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "baxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "caxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "title": {
                      "x": 0.05
                    },
                    "xaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    },
                    "yaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    }
                  }
                },
                "width": 500,
                "xaxis": {
                  "anchor": "y",
                  "domain": [
                    0,
                    0.98
                  ],
                  "title": {
                    "text": "cosine_similarity"
                  }
                },
                "xaxis2": {
                  "anchor": "y2",
                  "domain": [
                    0,
                    0.98
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "yaxis": {
                  "anchor": "x",
                  "domain": [
                    0,
                    0.485
                  ],
                  "title": {
                    "text": "count"
                  }
                },
                "yaxis2": {
                  "anchor": "x2",
                  "domain": [
                    0.515,
                    1
                  ],
                  "matches": "y",
                  "title": {
                    "text": "count"
                  }
                }
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "train accuracy: 89.1% ± 2.4%\n",
            "test accuracy: 88.8% ± 2.4%\n"
          ]
        }
      ],
      "source": [
        "# calculate accuracy (and its standard error) of predicting label=1 if similarity>x\n",
        "# the threshold x is chosen by sweeping from -1 to 1 in steps of 0.001\n",
        "def accuracy_and_se(\n",
        "    cosine_similarity: List[float], labeled_similarity: List[int]\n",
        ") -> Tuple[float, float]:\n",
        "    accuracies = []\n",
        "    for threshold_thousandths in range(-1000, 1000, 1):\n",
        "        threshold = threshold_thousandths / 1000\n",
        "        total = 0\n",
        "        correct = 0\n",
        "        for cs, ls in zip(cosine_similarity, labeled_similarity):\n",
        "            total += 1\n",
        "            if cs > threshold:\n",
        "                prediction = 1\n",
        "            else:\n",
        "                prediction = -1\n",
        "            if prediction == ls:\n",
        "                correct += 1\n",
        "        accuracy = correct / total\n",
        "        accuracies.append(accuracy)\n",
        "    a = max(accuracies)\n",
        "    n = len(cosine_similarity)\n",
        "    standard_error = (a * (1 - a) / n) ** 0.5  # standard error of binomial\n",
        "    return a, standard_error\n",
        "\n",
        "\n",
        "# check that training and test sets are balanced\n",
        "px.histogram(\n",
        "    df,\n",
        "    x=\"cosine_similarity\",\n",
        "    color=\"label\",\n",
        "    barmode=\"overlay\",\n",
        "    width=500,\n",
        "    facet_row=\"dataset\",\n",
        ").show()\n",
        "\n",
        "for dataset in [\"train\", \"test\"]:\n",
        "    data = df[df[\"dataset\"] == dataset]\n",
        "    a, se = accuracy_and_se(data[\"cosine_similarity\"], data[\"label\"])\n",
        "    print(f\"{dataset} accuracy: {a:0.1%} ± {1.96 * se:0.1%}\")\n"
      ]
    },
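    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The loop in `accuracy_and_se` checks 2,000 thresholds one example at a time in pure Python, which can become slow on larger datasets. Below is a minimal vectorized NumPy sketch of the same computation, at the cost of building a 2,000 × n array in memory. It assumes labels are +1/-1, reuses the same threshold grid (-1.000 to 0.999 in steps of 0.001) and the same binomial standard error, and should return the same values as the loop. The name `accuracy_and_se_vectorized` is our own and not part of the original notebook."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def accuracy_and_se_vectorized(cosine_similarity, labeled_similarity):\n",
        "    \"\"\"Sketch: best threshold accuracy and its binomial standard error, without Python loops.\"\"\"\n",
        "    cs = np.asarray(cosine_similarity, dtype=float)\n",
        "    labels = np.asarray(labeled_similarity)\n",
        "    # same grid as accuracy_and_se: -1.000, -0.999, ..., 0.999\n",
        "    thresholds = np.arange(-1000, 1000) / 1000\n",
        "    # predictions[i, j] is +1 if pair j's similarity exceeds threshold i, else -1\n",
        "    predictions = np.where(cs[None, :] > thresholds[:, None], 1, -1)\n",
        "    accuracies = (predictions == labels[None, :]).mean(axis=1)\n",
        "    a = accuracies.max()\n",
        "    n = len(cs)\n",
        "    standard_error = (a * (1 - a) / n) ** 0.5  # standard error of binomial\n",
        "    return a, standard_error\n",
        "\n",
        "\n",
        "# usage, mirroring the loop above:\n",
        "# data = df[df[\"dataset\"] == \"test\"]\n",
        "# a, se = accuracy_and_se_vectorized(data[\"cosine_similarity\"], data[\"label\"])"
      ]
    },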
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "zHLxlnsApgkR"
      },
      "source": [
        "## 7. Optimize the matrix using the training data provided"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "id": "z52V0x8IpgkR"
      },
      "outputs": [],
      "source": [
        "def embedding_multiplied_by_matrix(\n",
        "    embedding: List[float], matrix: torch.Tensor\n",
        ") -> np.ndarray:\n",
        "    embedding_tensor = torch.tensor(embedding).float()\n",
        "    modified_embedding = embedding_tensor @ matrix\n",
        "    modified_embedding = modified_embedding.detach().numpy()\n",
        "    return modified_embedding\n",
        "\n",
        "\n",
        "# compute custom embeddings and new cosine similarities (adds columns to df in place)\n",
        "def apply_matrix_to_embeddings_dataframe(matrix: torch.Tensor, df: pd.DataFrame) -> None:\n",
        "    for column in [\"text_1_embedding\", \"text_2_embedding\"]:\n",
        "        df[f\"{column}_custom\"] = df[column].apply(\n",
        "            lambda x: embedding_multiplied_by_matrix(x, matrix)\n",
        "        )\n",
        "    df[\"cosine_similarity_custom\"] = df.apply(\n",
        "        lambda row: cosine_similarity(\n",
        "            row[\"text_1_embedding_custom\"], row[\"text_2_embedding_custom\"]\n",
        "        ),\n",
        "        axis=1,\n",
        "    )\n"
      ]
    },
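    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "`apply_matrix_to_embeddings_dataframe` multiplies one embedding at a time through `DataFrame.apply`. Since each custom embedding is simply `embedding @ matrix`, all custom cosine similarities can also be computed at once by stacking the embeddings into 2-D arrays. The NumPy sketch below does this (the helper name `custom_cosine_similarities` is hypothetical, not from the original notebook) and runs a quick sanity check on random data: with the identity matrix, the custom similarities equal the plain cosine similarities."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "import numpy as np\n",
        "\n",
        "\n",
        "def custom_cosine_similarities(embeddings_1, embeddings_2, matrix):\n",
        "    \"\"\"Row-wise cosine similarity between embeddings_1 @ matrix and embeddings_2 @ matrix.\n",
        "\n",
        "    embeddings_1, embeddings_2: shape (n_pairs, embedding_length)\n",
        "    matrix: shape (embedding_length, modified_embedding_length)\n",
        "    \"\"\"\n",
        "    custom_1 = np.asarray(embeddings_1, dtype=float) @ matrix\n",
        "    custom_2 = np.asarray(embeddings_2, dtype=float) @ matrix\n",
        "    dot_products = np.sum(custom_1 * custom_2, axis=1)\n",
        "    norms = np.linalg.norm(custom_1, axis=1) * np.linalg.norm(custom_2, axis=1)\n",
        "    return dot_products / norms\n",
        "\n",
        "\n",
        "# sanity check on random data: the identity matrix leaves cosine similarity unchanged\n",
        "rng = np.random.default_rng(0)\n",
        "example_1 = rng.normal(size=(5, 8))\n",
        "example_2 = rng.normal(size=(5, 8))\n",
        "plain = np.sum(example_1 * example_2, axis=1) / (\n",
        "    np.linalg.norm(example_1, axis=1) * np.linalg.norm(example_2, axis=1)\n",
        ")\n",
        "assert np.allclose(custom_cosine_similarities(example_1, example_2, np.eye(8)), plain)\n",
        "\n",
        "# usage with the notebook's data (matrix converted to a NumPy array), e.g.:\n",
        "# custom_cosine_similarities(\n",
        "#     np.stack(df[\"text_1_embedding\"].values),\n",
        "#     np.stack(df[\"text_2_embedding\"].values),\n",
        "#     matrix.detach().numpy(),\n",
        "# )"
      ]
    },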
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "id": "p2ZSXu6spgkR"
      },
      "outputs": [],
      "source": [
        "def optimize_matrix(\n",
        "    modified_embedding_length: int = 2048,  # in my brief experimentation, bigger was better (2048 is the length of babbage embeddings)\n",
        "    batch_size: int = 100,\n",
        "    max_epochs: int = 100,\n",
        "    learning_rate: float = 100.0,  # seemed to work best when similar to batch size - feel free to try a range of values\n",
        "    dropout_fraction: float = 0.0,  # in my testing, dropout helped by a couple percentage points (definitely not necessary)\n",
        "    df: pd.DataFrame = df,\n",
        "    print_progress: bool = True,\n",
        "    save_results: bool = True,\n",
        ") -> pd.DataFrame:\n",
        "    \"\"\"Optimize a projection matrix on the training data; return per-epoch results (including each epoch's matrix).\"\"\"\n",
        "    run_id = random.randint(0, 2 ** 31 - 1)  # (range is arbitrary)\n",
        "    # convert from dataframe to torch tensors\n",
        "    # e is for embedding, s for similarity label\n",
        "    def tensors_from_dataframe(\n",
        "        df: pd.DataFrame,\n",
        "        embedding_column_1: str,\n",
        "        embedding_column_2: str,\n",
        "        similarity_label_column: str,\n",
        "    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:\n",
        "        e1 = np.stack(np.array(df[embedding_column_1].values))\n",
        "        e2 = np.stack(np.array(df[embedding_column_2].values))\n",
        "        s = np.stack(np.array(df[similarity_label_column].astype(\"float\").values))\n",
        "\n",
        "        e1 = torch.from_numpy(e1).float()\n",
        "        e2 = torch.from_numpy(e2).float()\n",
        "        s = torch.from_numpy(s).float()\n",
        "\n",
        "        return e1, e2, s\n",
        "\n",
        "    e1_train, e2_train, s_train = tensors_from_dataframe(\n",
        "        df[df[\"dataset\"] == \"train\"], \"text_1_embedding\", \"text_2_embedding\", \"label\"\n",
        "    )\n",
        "    e1_test, e2_test, s_test = tensors_from_dataframe(\n",
        "        df[df[\"dataset\"] == \"test\"], \"text_1_embedding\", \"text_2_embedding\", \"label\"\n",
        "    )\n",
        "\n",
        "    # create dataset and loader\n",
        "    dataset = torch.utils.data.TensorDataset(e1_train, e2_train, s_train)\n",
        "    train_loader = torch.utils.data.DataLoader(\n",
        "        dataset, batch_size=batch_size, shuffle=True\n",
        "    )\n",
        "\n",
        "    # define model (similarity of projected embeddings)\n",
        "    def model(embedding_1, embedding_2, matrix, dropout_fraction=dropout_fraction):\n",
        "        e1 = torch.nn.functional.dropout(embedding_1, p=dropout_fraction)\n",
        "        e2 = torch.nn.functional.dropout(embedding_2, p=dropout_fraction)\n",
        "        modified_embedding_1 = e1 @ matrix  # @ is matrix multiplication\n",
        "        modified_embedding_2 = e2 @ matrix\n",
        "        similarity = torch.nn.functional.cosine_similarity(\n",
        "            modified_embedding_1, modified_embedding_2\n",
        "        )\n",
        "        return similarity\n",
        "\n",
        "    # define loss function to minimize\n",
        "    def mse_loss(predictions, targets):\n",
        "        difference = predictions - targets\n",
        "        return torch.sum(difference * difference) / difference.numel()\n",
        "\n",
        "    # initialize projection matrix\n",
        "    embedding_length = len(df[\"text_1_embedding\"].values[0])\n",
        "    matrix = torch.randn(\n",
        "        embedding_length, modified_embedding_length, requires_grad=True\n",
        "    )\n",
        "\n",
        "    epochs, types, losses, accuracies, matrices = [], [], [], [], []\n",
        "    for epoch in range(1, 1 + max_epochs):\n",
        "        # iterate through training dataloader\n",
        "        for a, b, actual_similarity in train_loader:\n",
        "            # generate prediction\n",
        "            predicted_similarity = model(a, b, matrix)\n",
        "            # get loss and perform backpropagation\n",
        "            loss = mse_loss(predicted_similarity, actual_similarity)\n",
        "            loss.backward()\n",
        "            # update the weights\n",
        "            with torch.no_grad():\n",
        "                matrix -= matrix.grad * learning_rate\n",
        "                # set gradients to zero\n",
        "                matrix.grad.zero_()\n",
        "        # calculate test loss (without dropout, since this is evaluation)\n",
        "        test_predictions = model(e1_test, e2_test, matrix, dropout_fraction=0.0)\n",
        "        test_loss = mse_loss(test_predictions, s_test)\n",
        "\n",
        "        # compute custom embeddings and new cosine similarities\n",
        "        apply_matrix_to_embeddings_dataframe(matrix, df)\n",
        "\n",
        "        # calculate train and test accuracy\n",
        "        for dataset in [\"train\", \"test\"]:\n",
        "            data = df[df[\"dataset\"] == dataset]\n",
        "            a, se = accuracy_and_se(data[\"cosine_similarity_custom\"], data[\"label\"])\n",
        "\n",
        "            # record results of each epoch\n",
        "            epochs.append(epoch)\n",
        "            types.append(dataset)\n",
        "            # train loss is from the epoch's last minibatch; test loss is on the full test set\n",
        "            losses.append(loss.item() if dataset == \"train\" else test_loss.item())\n",
        "            accuracies.append(a)\n",
        "            # copy, because .numpy() shares memory with `matrix`, which is updated in place\n",
        "            matrices.append(matrix.detach().numpy().copy())\n",
        "\n",
        "            # optionally print accuracies\n",
        "            if print_progress:\n",
        "                print(\n",
        "                    f\"Epoch {epoch}/{max_epochs}: {dataset} accuracy: {a:0.1%} ± {1.96 * se:0.1%}\"\n",
        "                )\n",
        "\n",
        "    data = pd.DataFrame(\n",
        "        {\"epoch\": epochs, \"type\": types, \"loss\": losses, \"accuracy\": accuracies}\n",
        "    )\n",
        "    data[\"run_id\"] = run_id\n",
        "    data[\"modified_embedding_length\"] = modified_embedding_length\n",
        "    data[\"batch_size\"] = batch_size\n",
        "    data[\"max_epochs\"] = max_epochs\n",
        "    data[\"learning_rate\"] = learning_rate\n",
        "    data[\"dropout_fraction\"] = dropout_fraction\n",
        "    # saving every single matrix can get big; feel free to delete/change\n",
        "    data[\"matrix\"] = matrices\n",
        "    if save_results:\n",
        "        # note: CSV stores each matrix as a (truncated) string; use data.to_pickle to keep them\n",
        "        data.to_csv(f\"{run_id}_optimization_results.csv\", index=False)\n",
        "\n",
        "    return data\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "id": "nlcUW-zEpgkS",
        "outputId": "4bd4bdff-628a-406f-fffe-aedbfad66446"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Epoch 1/30: train accuracy: 89.1% ± 2.4%\n",
            "Epoch 1/30: test accuracy: 88.4% ± 2.4%\n",
            "Epoch 2/30: train accuracy: 89.5% ± 2.3%\n",
            "Epoch 2/30: test accuracy: 88.8% ± 2.4%\n",
            "Epoch 3/30: train accuracy: 90.6% ± 2.2%\n",
            "Epoch 3/30: test accuracy: 89.3% ± 2.3%\n",
            "Epoch 4/30: train accuracy: 91.2% ± 2.2%\n",
            "Epoch 4/30: test accuracy: 89.7% ± 2.3%\n",
            "Epoch 5/30: train accuracy: 91.5% ± 2.1%\n",
            "Epoch 5/30: test accuracy: 90.0% ± 2.3%\n",
            "Epoch 6/30: train accuracy: 91.9% ± 2.1%\n",
            "Epoch 6/30: test accuracy: 90.4% ± 2.2%\n",
            "Epoch 7/30: train accuracy: 92.2% ± 2.0%\n",
            "Epoch 7/30: test accuracy: 90.7% ± 2.2%\n",
            "Epoch 8/30: train accuracy: 92.7% ± 2.0%\n",
            "Epoch 8/30: test accuracy: 90.9% ± 2.2%\n",
            "Epoch 9/30: train accuracy: 92.7% ± 2.0%\n",
            "Epoch 9/30: test accuracy: 91.0% ± 2.2%\n",
            "Epoch 10/30: train accuracy: 93.0% ± 1.9%\n",
            "Epoch 10/30: test accuracy: 91.6% ± 2.1%\n",
            "Epoch 11/30: train accuracy: 93.1% ± 1.9%\n",
            "Epoch 11/30: test accuracy: 91.8% ± 2.1%\n",
            "Epoch 12/30: train accuracy: 93.4% ± 1.9%\n",
            "Epoch 12/30: test accuracy: 92.1% ± 2.0%\n",
            "Epoch 13/30: train accuracy: 93.6% ± 1.9%\n",
            "Epoch 13/30: test accuracy: 92.4% ± 2.0%\n",
            "Epoch 14/30: train accuracy: 93.7% ± 1.8%\n",
            "Epoch 14/30: test accuracy: 92.7% ± 2.0%\n",
            "Epoch 15/30: train accuracy: 93.7% ± 1.8%\n",
            "Epoch 15/30: test accuracy: 92.7% ± 2.0%\n",
            "Epoch 16/30: train accuracy: 94.0% ± 1.8%\n",
            "Epoch 16/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 17/30: train accuracy: 94.0% ± 1.8%\n",
            "Epoch 17/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 18/30: train accuracy: 94.2% ± 1.8%\n",
            "Epoch 18/30: test accuracy: 93.1% ± 1.9%\n",
            "Epoch 19/30: train accuracy: 94.2% ± 1.8%\n",
            "Epoch 19/30: test accuracy: 93.1% ± 1.9%\n",
            "Epoch 20/30: train accuracy: 94.3% ± 1.8%\n",
            "Epoch 20/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 21/30: train accuracy: 94.5% ± 1.7%\n",
            "Epoch 21/30: test accuracy: 93.1% ± 1.9%\n",
            "Epoch 22/30: train accuracy: 94.5% ± 1.7%\n",
            "Epoch 22/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 23/30: train accuracy: 94.6% ± 1.7%\n",
            "Epoch 23/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 24/30: train accuracy: 94.6% ± 1.7%\n",
            "Epoch 24/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 25/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 25/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 26/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 26/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 27/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 27/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 28/30: train accuracy: 94.9% ± 1.7%\n",
            "Epoch 28/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 29/30: train accuracy: 94.9% ± 1.7%\n",
            "Epoch 29/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 30/30: train accuracy: 94.9% ± 1.7%\n",
            "Epoch 30/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 1/30: train accuracy: 89.7% ± 2.3%\n",
            "Epoch 1/30: test accuracy: 89.1% ± 2.4%\n",
            "Epoch 2/30: train accuracy: 89.8% ± 2.3%\n",
            "Epoch 2/30: test accuracy: 89.9% ± 2.3%\n",
            "Epoch 3/30: train accuracy: 90.3% ± 2.2%\n",
            "Epoch 3/30: test accuracy: 90.0% ± 2.3%\n",
            "Epoch 4/30: train accuracy: 91.0% ± 2.2%\n",
            "Epoch 4/30: test accuracy: 90.3% ± 2.2%\n",
            "Epoch 5/30: train accuracy: 91.3% ± 2.1%\n",
            "Epoch 5/30: test accuracy: 90.3% ± 2.2%\n",
            "Epoch 6/30: train accuracy: 91.8% ± 2.1%\n",
            "Epoch 6/30: test accuracy: 90.4% ± 2.2%\n",
            "Epoch 7/30: train accuracy: 92.4% ± 2.0%\n",
            "Epoch 7/30: test accuracy: 91.0% ± 2.2%\n",
            "Epoch 8/30: train accuracy: 92.8% ± 2.0%\n",
            "Epoch 8/30: test accuracy: 91.3% ± 2.1%\n",
            "Epoch 9/30: train accuracy: 93.1% ± 1.9%\n",
            "Epoch 9/30: test accuracy: 91.6% ± 2.1%\n",
            "Epoch 10/30: train accuracy: 93.4% ± 1.9%\n",
            "Epoch 10/30: test accuracy: 91.9% ± 2.1%\n",
            "Epoch 11/30: train accuracy: 93.4% ± 1.9%\n",
            "Epoch 11/30: test accuracy: 91.8% ± 2.1%\n",
            "Epoch 12/30: train accuracy: 93.6% ± 1.9%\n",
            "Epoch 12/30: test accuracy: 92.1% ± 2.0%\n",
            "Epoch 13/30: train accuracy: 93.7% ± 1.8%\n",
            "Epoch 13/30: test accuracy: 92.4% ± 2.0%\n",
            "Epoch 14/30: train accuracy: 93.7% ± 1.8%\n",
            "Epoch 14/30: test accuracy: 92.5% ± 2.0%\n",
            "Epoch 15/30: train accuracy: 93.9% ± 1.8%\n",
            "Epoch 15/30: test accuracy: 92.8% ± 2.0%\n",
            "Epoch 16/30: train accuracy: 94.0% ± 1.8%\n",
            "Epoch 16/30: test accuracy: 92.8% ± 2.0%\n",
            "Epoch 17/30: train accuracy: 94.0% ± 1.8%\n",
            "Epoch 17/30: test accuracy: 92.8% ± 2.0%\n",
            "Epoch 18/30: train accuracy: 94.2% ± 1.8%\n",
            "Epoch 18/30: test accuracy: 92.8% ± 2.0%\n",
            "Epoch 19/30: train accuracy: 94.2% ± 1.8%\n",
            "Epoch 19/30: test accuracy: 92.8% ± 2.0%\n",
            "Epoch 20/30: train accuracy: 94.2% ± 1.8%\n",
            "Epoch 20/30: test accuracy: 93.1% ± 1.9%\n",
            "Epoch 21/30: train accuracy: 94.3% ± 1.8%\n",
            "Epoch 21/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 22/30: train accuracy: 94.3% ± 1.8%\n",
            "Epoch 22/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 23/30: train accuracy: 94.5% ± 1.7%\n",
            "Epoch 23/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 24/30: train accuracy: 94.5% ± 1.7%\n",
            "Epoch 24/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 25/30: train accuracy: 94.6% ± 1.7%\n",
            "Epoch 25/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 26/30: train accuracy: 94.6% ± 1.7%\n",
            "Epoch 26/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 27/30: train accuracy: 94.6% ± 1.7%\n",
            "Epoch 27/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 28/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 28/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 29/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 29/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 30/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 30/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 1/30: train accuracy: 90.7% ± 2.2%\n",
            "Epoch 1/30: test accuracy: 89.9% ± 2.3%\n",
            "Epoch 2/30: train accuracy: 90.9% ± 2.2%\n",
            "Epoch 2/30: test accuracy: 90.3% ± 2.2%\n",
            "Epoch 3/30: train accuracy: 91.6% ± 2.1%\n",
            "Epoch 3/30: test accuracy: 90.3% ± 2.2%\n",
            "Epoch 4/30: train accuracy: 92.2% ± 2.0%\n",
            "Epoch 4/30: test accuracy: 90.7% ± 2.2%\n",
            "Epoch 5/30: train accuracy: 92.4% ± 2.0%\n",
            "Epoch 5/30: test accuracy: 91.3% ± 2.1%\n",
            "Epoch 6/30: train accuracy: 92.5% ± 2.0%\n",
            "Epoch 6/30: test accuracy: 91.8% ± 2.1%\n",
            "Epoch 7/30: train accuracy: 93.0% ± 1.9%\n",
            "Epoch 7/30: test accuracy: 92.2% ± 2.0%\n",
            "Epoch 8/30: train accuracy: 93.1% ± 1.9%\n",
            "Epoch 8/30: test accuracy: 92.7% ± 2.0%\n",
            "Epoch 9/30: train accuracy: 93.3% ± 1.9%\n",
            "Epoch 9/30: test accuracy: 92.5% ± 2.0%\n",
            "Epoch 10/30: train accuracy: 93.4% ± 1.9%\n",
            "Epoch 10/30: test accuracy: 92.7% ± 2.0%\n",
            "Epoch 11/30: train accuracy: 93.6% ± 1.9%\n",
            "Epoch 11/30: test accuracy: 92.8% ± 2.0%\n",
            "Epoch 12/30: train accuracy: 93.7% ± 1.8%\n",
            "Epoch 12/30: test accuracy: 92.8% ± 2.0%\n",
            "Epoch 13/30: train accuracy: 94.0% ± 1.8%\n",
            "Epoch 13/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 14/30: train accuracy: 93.9% ± 1.8%\n",
            "Epoch 14/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 15/30: train accuracy: 94.2% ± 1.8%\n",
            "Epoch 15/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 16/30: train accuracy: 94.2% ± 1.8%\n",
            "Epoch 16/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 17/30: train accuracy: 94.3% ± 1.8%\n",
            "Epoch 17/30: test accuracy: 93.0% ± 1.9%\n",
            "Epoch 18/30: train accuracy: 94.5% ± 1.7%\n",
            "Epoch 18/30: test accuracy: 93.1% ± 1.9%\n",
            "Epoch 19/30: train accuracy: 94.5% ± 1.7%\n",
            "Epoch 19/30: test accuracy: 93.1% ± 1.9%\n",
            "Epoch 20/30: train accuracy: 94.6% ± 1.7%\n",
            "Epoch 20/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 21/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 21/30: test accuracy: 93.3% ± 1.9%\n",
            "Epoch 22/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 22/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 23/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 23/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 24/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 24/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 25/30: train accuracy: 94.8% ± 1.7%\n",
            "Epoch 25/30: test accuracy: 93.4% ± 1.9%\n",
            "Epoch 26/30: train accuracy: 94.9% ± 1.7%\n",
            "Epoch 26/30: test accuracy: 93.6% ± 1.9%\n",
            "Epoch 27/30: train accuracy: 94.9% ± 1.7%\n",
            "Epoch 27/30: test accuracy: 93.6% ± 1.9%\n",
            "Epoch 28/30: train accuracy: 94.9% ± 1.7%\n",
            "Epoch 28/30: test accuracy: 93.6% ± 1.9%\n",
            "Epoch 29/30: train accuracy: 95.1% ± 1.6%\n",
            "Epoch 29/30: test accuracy: 93.6% ± 1.9%\n",
            "Epoch 30/30: train accuracy: 95.1% ± 1.6%\n",
            "Epoch 30/30: test accuracy: 93.6% ± 1.9%\n"
          ]
        }
      ],
      "source": [
        "# example hyperparameter search\n",
        "# I recommend starting with max_epochs=10 while initially exploring\n",
        "results = []\n",
        "max_epochs = 30\n",
        "dropout_fraction = 0.2\n",
        "for batch_size, learning_rate in [(10, 10), (100, 100), (1000, 1000)]:\n",
        "    result = optimize_matrix(\n",
        "        batch_size=batch_size,\n",
        "        learning_rate=learning_rate,\n",
        "        max_epochs=max_epochs,\n",
        "        dropout_fraction=dropout_fraction,\n",
        "        save_results=False,\n",
        "    )\n",
        "    results.append(result)\n"
      ]
    },
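    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Each call to `optimize_matrix` returns a DataFrame with one row per epoch and dataset (the `type` column), including `accuracy` and `matrix`, and `results` collects one such DataFrame per hyperparameter setting. A simple way to pick a setting and epoch is to concatenate the runs and take the test row with the highest accuracy. The helper below is a sketch (`best_test_row` is a hypothetical name) that assumes the column names used in `optimize_matrix`. Because it selects on the test split, the chosen row's test accuracy is an optimistic estimate; holding out a separate validation split avoids this if you have enough data."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [],
      "source": [
        "from typing import List\n",
        "\n",
        "import pandas as pd\n",
        "\n",
        "\n",
        "def best_test_row(results: List[pd.DataFrame]) -> pd.Series:\n",
        "    \"\"\"Return the row (run and epoch) with the highest test accuracy across all runs.\"\"\"\n",
        "    all_results = pd.concat(results, ignore_index=True)\n",
        "    test_results = all_results[all_results[\"type\"] == \"test\"]\n",
        "    return test_results.loc[test_results[\"accuracy\"].idxmax()]\n",
        "\n",
        "\n",
        "# usage after the search above:\n",
        "# best = best_test_row(results)\n",
        "# print(best[[\"run_id\", \"batch_size\", \"learning_rate\", \"epoch\", \"accuracy\"]])\n",
        "# best_matrix = best[\"matrix\"]  # NumPy array of shape (embedding_length, modified_embedding_length)"
      ]
    },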
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "id": "PoTZWC1SpgkS",
        "outputId": "207360e5-fd07-4180-a143-0ec5dd27ffe1"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.plotly.v1+json": {
              "config": {
                "plotlyServerURL": "https://plot.ly"
              },
              "data": [
                {
                  "customdata": [
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "train",
                  "line": {
                    "color": "#636efa",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "train",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x7",
                  "y": [
                    1.1652076244354248,
                    0.8190054297447205,
                    0.9203463792800903,
                    0.9908844232559204,
                    0.8408004641532898,
                    0.7119297981262207,
                    0.6632728576660156,
                    0.8289280533790588,
                    0.8687358498573303,
                    0.8235021829605103,
                    0.895695149898529,
                    0.6677632331848145,
                    0.7526643872261047,
                    0.7708764672279358,
                    0.68276047706604,
                    0.6613249778747559,
                    0.5960850119590759,
                    0.8617165684700012,
                    0.724422037601471,
                    0.9765143394470215,
                    0.5958823561668396,
                    0.7277706265449524,
                    0.7929649353027344,
                    0.8311190009117126,
                    0.484933465719223,
                    0.6846191883087158,
                    0.6711297035217285,
                    0.738968551158905,
                    0.5267000198364258,
                    0.9111422300338745
                  ],
                  "yaxis": "y7"
                },
                {
                  "customdata": [
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "train",
                  "line": {
                    "color": "#636efa",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "train",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x5",
                  "y": [
                    1.06173574924469,
                    0.9195982813835144,
                    0.8605211973190308,
                    0.7780925631523132,
                    0.7680334448814392,
                    0.7896735072135925,
                    0.7652512788772583,
                    0.7015480399131775,
                    0.8019503951072693,
                    0.7844551801681519,
                    0.823682427406311,
                    0.711807131767273,
                    0.7855805158615112,
                    0.7014225125312805,
                    0.7862630486488342,
                    0.6663534045219421,
                    0.7388879060745239,
                    0.6876973509788513,
                    0.7274147272109985,
                    0.7191041111946106,
                    0.8075127601623535,
                    0.7195712924003601,
                    0.746185839176178,
                    0.7220138311386108,
                    0.7456589341163635,
                    0.6642791032791138,
                    0.7399784326553345,
                    0.7393214106559753,
                    0.6680636405944824,
                    0.6562733054161072
                  ],
                  "yaxis": "y5"
                },
                {
                  "customdata": [
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "train",
                  "line": {
                    "color": "#636efa",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "train",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x3",
                  "y": [
                    1.299358606338501,
                    1.0686416625976562,
                    0.8467883467674255,
                    0.7818496823310852,
                    0.7742239236831665,
                    0.7678887844085693,
                    0.7696077227592468,
                    0.766943097114563,
                    0.7572135925292969,
                    0.7593567967414856,
                    0.7546665668487549,
                    0.7499144077301025,
                    0.7492073178291321,
                    0.7394933700561523,
                    0.743760883808136,
                    0.7366983294487,
                    0.7340478301048279,
                    0.7297782897949219,
                    0.7292298674583435,
                    0.7229472994804382,
                    0.7246285080909729,
                    0.721783459186554,
                    0.7177888751029968,
                    0.7198930978775024,
                    0.7123011946678162,
                    0.7132685780525208,
                    0.7121831178665161,
                    0.7118210196495056,
                    0.7035670280456543,
                    0.7066351771354675
                  ],
                  "yaxis": "y3"
                },
                {
                  "customdata": [
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "test",
                  "line": {
                    "color": "#EF553B",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "test",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x7",
                  "y": [
                    1.1514005661010742,
                    0.9815413355827332,
                    0.8687632083892822,
                    0.8124286532402039,
                    0.7918410301208496,
                    0.7831570506095886,
                    0.7777765393257141,
                    0.7735418677330017,
                    0.7701976895332336,
                    0.7720419764518738,
                    0.7726959586143494,
                    0.7650123834609985,
                    0.7650409936904907,
                    0.765683114528656,
                    0.7626248598098755,
                    0.7623012065887451,
                    0.7609940767288208,
                    0.7587862610816956,
                    0.7559080123901367,
                    0.7571383118629456,
                    0.7588285803794861,
                    0.7556950449943542,
                    0.7562108039855957,
                    0.7484216690063477,
                    0.7531804442405701,
                    0.7502257823944092,
                    0.7496891617774963,
                    0.7472137808799744,
                    0.748519241809845,
                    0.7483490705490112
                  ],
                  "yaxis": "y7"
                },
                {
                  "customdata": [
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "test",
                  "line": {
                    "color": "#EF553B",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "test",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x5",
                  "y": [
                    1.1180120706558228,
                    0.9514004588127136,
                    0.8425014019012451,
                    0.8078141212463379,
                    0.7853298187255859,
                    0.7790449857711792,
                    0.7733959555625916,
                    0.7792260646820068,
                    0.7700104117393494,
                    0.7700384855270386,
                    0.764447033405304,
                    0.7696305513381958,
                    0.7642591595649719,
                    0.7629777193069458,
                    0.7620665431022644,
                    0.7622931003570557,
                    0.7602745890617371,
                    0.7564830780029297,
                    0.761269748210907,
                    0.7550154328346252,
                    0.7560049295425415,
                    0.7538376450538635,
                    0.7503026127815247,
                    0.7528620958328247,
                    0.7485130429267883,
                    0.7481465339660645,
                    0.7483287453651428,
                    0.742965817451477,
                    0.7445206642150879,
                    0.7476803064346313
                  ],
                  "yaxis": "y5"
                },
                {
                  "customdata": [
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>loss=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "test",
                  "line": {
                    "color": "#EF553B",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "test",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x3",
                  "y": [
                    1.0644111633300781,
                    0.8446540236473083,
                    0.783637285232544,
                    0.7773287892341614,
                    0.7771273851394653,
                    0.7754599452018738,
                    0.7732424736022949,
                    0.7710903882980347,
                    0.772106945514679,
                    0.7645512223243713,
                    0.7639281153678894,
                    0.7609775066375732,
                    0.7558433413505554,
                    0.7579177618026733,
                    0.7527595162391663,
                    0.7520467042922974,
                    0.7534357309341431,
                    0.7487133145332336,
                    0.7478086352348328,
                    0.747072160243988,
                    0.7411499619483948,
                    0.7459169030189514,
                    0.7451942563056946,
                    0.7394304275512695,
                    0.7337468862533569,
                    0.734693169593811,
                    0.7376500368118286,
                    0.7371401190757751,
                    0.731780469417572,
                    0.727291464805603
                  ],
                  "yaxis": "y3"
                }
              ],
              "layout": {
                "annotations": [
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "batch_size=10",
                    "x": 0.15666666666666665,
                    "xanchor": "center",
                    "xref": "paper",
                    "y": 0.9999999999999998,
                    "yanchor": "bottom",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "batch_size=100",
                    "x": 0.49,
                    "xanchor": "center",
                    "xref": "paper",
                    "y": 0.9999999999999998,
                    "yanchor": "bottom",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "batch_size=1000",
                    "x": 0.8233333333333333,
                    "xanchor": "center",
                    "xref": "paper",
                    "y": 0.9999999999999998,
                    "yanchor": "bottom",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "learning_rate=1000",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.15666666666666665,
                    "yanchor": "middle",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "learning_rate=100",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.4999999999999999,
                    "yanchor": "middle",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "learning_rate=10",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.8433333333333332,
                    "yanchor": "middle",
                    "yref": "paper"
                  }
                ],
                "legend": {
                  "title": {
                    "text": "type"
                  },
                  "tracegroupgap": 0
                },
                "margin": {
                  "t": 60
                },
                "template": {
                  "data": {
                    "bar": [
                      {
                        "error_x": {
                          "color": "#2a3f5f"
                        },
                        "error_y": {
                          "color": "#2a3f5f"
                        },
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "bar"
                      }
                    ],
                    "barpolar": [
                      {
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "barpolar"
                      }
                    ],
                    "carpet": [
                      {
                        "aaxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "baxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "type": "carpet"
                      }
                    ],
                    "choropleth": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "choropleth"
                      }
                    ],
                    "contour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "contour"
                      }
                    ],
                    "contourcarpet": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "contourcarpet"
                      }
                    ],
                    "heatmap": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmap"
                      }
                    ],
                    "heatmapgl": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmapgl"
                      }
                    ],
                    "histogram": [
                      {
                        "marker": {
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "histogram"
                      }
                    ],
                    "histogram2d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2d"
                      }
                    ],
                    "histogram2dcontour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2dcontour"
                      }
                    ],
                    "mesh3d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "mesh3d"
                      }
                    ],
                    "parcoords": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "parcoords"
                      }
                    ],
                    "pie": [
                      {
                        "automargin": true,
                        "type": "pie"
                      }
                    ],
                    "scatter": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter"
                      }
                    ],
                    "scatter3d": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter3d"
                      }
                    ],
                    "scattercarpet": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattercarpet"
                      }
                    ],
                    "scattergeo": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergeo"
                      }
                    ],
                    "scattergl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergl"
                      }
                    ],
                    "scattermapbox": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattermapbox"
                      }
                    ],
                    "scatterpolar": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolar"
                      }
                    ],
                    "scatterpolargl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolargl"
                      }
                    ],
                    "scatterternary": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterternary"
                      }
                    ],
                    "surface": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "surface"
                      }
                    ],
                    "table": [
                      {
                        "cells": {
                          "fill": {
                            "color": "#EBF0F8"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "header": {
                          "fill": {
                            "color": "#C8D4E3"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "type": "table"
                      }
                    ]
                  },
                  "layout": {
                    "annotationdefaults": {
                      "arrowcolor": "#2a3f5f",
                      "arrowhead": 0,
                      "arrowwidth": 1
                    },
                    "autotypenumbers": "strict",
                    "coloraxis": {
                      "colorbar": {
                        "outlinewidth": 0,
                        "ticks": ""
                      }
                    },
                    "colorscale": {
                      "diverging": [
                        [
                          0,
                          "#8e0152"
                        ],
                        [
                          0.1,
                          "#c51b7d"
                        ],
                        [
                          0.2,
                          "#de77ae"
                        ],
                        [
                          0.3,
                          "#f1b6da"
                        ],
                        [
                          0.4,
                          "#fde0ef"
                        ],
                        [
                          0.5,
                          "#f7f7f7"
                        ],
                        [
                          0.6,
                          "#e6f5d0"
                        ],
                        [
                          0.7,
                          "#b8e186"
                        ],
                        [
                          0.8,
                          "#7fbc41"
                        ],
                        [
                          0.9,
                          "#4d9221"
                        ],
                        [
                          1,
                          "#276419"
                        ]
                      ],
                      "sequential": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ],
                      "sequentialminus": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ]
                    },
                    "colorway": [
                      "#636efa",
                      "#EF553B",
                      "#00cc96",
                      "#ab63fa",
                      "#FFA15A",
                      "#19d3f3",
                      "#FF6692",
                      "#B6E880",
                      "#FF97FF",
                      "#FECB52"
                    ],
                    "font": {
                      "color": "#2a3f5f"
                    },
                    "geo": {
                      "bgcolor": "white",
                      "lakecolor": "white",
                      "landcolor": "#E5ECF6",
                      "showlakes": true,
                      "showland": true,
                      "subunitcolor": "white"
                    },
                    "hoverlabel": {
                      "align": "left"
                    },
                    "hovermode": "closest",
                    "mapbox": {
                      "style": "light"
                    },
                    "paper_bgcolor": "white",
                    "plot_bgcolor": "#E5ECF6",
                    "polar": {
                      "angularaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "radialaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "scene": {
                      "xaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "yaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "zaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      }
                    },
                    "shapedefaults": {
                      "line": {
                        "color": "#2a3f5f"
                      }
                    },
                    "ternary": {
                      "aaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "baxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "caxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "title": {
                      "x": 0.05
                    },
                    "xaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    },
                    "yaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    }
                  }
                },
                "width": 500,
                "xaxis": {
                  "anchor": "y",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "title": {
                    "text": "epoch"
                  }
                },
                "xaxis2": {
                  "anchor": "y2",
                  "domain": [
                    0.3333333333333333,
                    0.6466666666666666
                  ],
                  "matches": "x",
                  "title": {
                    "text": "epoch"
                  }
                },
                "xaxis3": {
                  "anchor": "y3",
                  "domain": [
                    0.6666666666666666,
                    0.98
                  ],
                  "matches": "x",
                  "title": {
                    "text": "epoch"
                  }
                },
                "xaxis4": {
                  "anchor": "y4",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis5": {
                  "anchor": "y5",
                  "domain": [
                    0.3333333333333333,
                    0.6466666666666666
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis6": {
                  "anchor": "y6",
                  "domain": [
                    0.6666666666666666,
                    0.98
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis7": {
                  "anchor": "y7",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis8": {
                  "anchor": "y8",
                  "domain": [
                    0.3333333333333333,
                    0.6466666666666666
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis9": {
                  "anchor": "y9",
                  "domain": [
                    0.6666666666666666,
                    0.98
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "yaxis": {
                  "anchor": "x",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "title": {
                    "text": "loss"
                  }
                },
                "yaxis2": {
                  "anchor": "x2",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis3": {
                  "anchor": "x3",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis4": {
                  "anchor": "x4",
                  "domain": [
                    0.34333333333333327,
                    0.6566666666666665
                  ],
                  "matches": "y",
                  "title": {
                    "text": "loss"
                  }
                },
                "yaxis5": {
                  "anchor": "x5",
                  "domain": [
                    0.34333333333333327,
                    0.6566666666666665
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis6": {
                  "anchor": "x6",
                  "domain": [
                    0.34333333333333327,
                    0.6566666666666665
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis7": {
                  "anchor": "x7",
                  "domain": [
                    0.6866666666666665,
                    0.9999999999999998
                  ],
                  "matches": "y",
                  "title": {
                    "text": "loss"
                  }
                },
                "yaxis8": {
                  "anchor": "x8",
                  "domain": [
                    0.6866666666666665,
                    0.9999999999999998
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis9": {
                  "anchor": "x9",
                  "domain": [
                    0.6866666666666665,
                    0.9999999999999998
                  ],
                  "matches": "y",
                  "showticklabels": false
                }
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "data": {
            "application/vnd.plotly.v1+json": {
              "config": {
                "plotlyServerURL": "https://plot.ly"
              },
              "data": [
                {
                  "customdata": [
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "train",
                  "line": {
                    "color": "#636efa",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "train",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x7",
                  "y": [
                    0.8907185628742516,
                    0.8952095808383234,
                    0.905688622754491,
                    0.9116766467065869,
                    0.9146706586826348,
                    0.9191616766467066,
                    0.9221556886227545,
                    0.9266467065868264,
                    0.9266467065868264,
                    0.9296407185628742,
                    0.9311377245508982,
                    0.9341317365269461,
                    0.9356287425149701,
                    0.937125748502994,
                    0.937125748502994,
                    0.9401197604790419,
                    0.9401197604790419,
                    0.9416167664670658,
                    0.9416167664670658,
                    0.9431137724550899,
                    0.9446107784431138,
                    0.9446107784431138,
                    0.9461077844311377,
                    0.9461077844311377,
                    0.9476047904191617,
                    0.9476047904191617,
                    0.9476047904191617,
                    0.9491017964071856,
                    0.9491017964071856,
                    0.9491017964071856
                  ],
                  "yaxis": "y7"
                },
                {
                  "customdata": [
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "train",
                  "line": {
                    "color": "#636efa",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "train",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x5",
                  "y": [
                    0.8967065868263473,
                    0.8982035928143712,
                    0.9026946107784432,
                    0.9101796407185628,
                    0.9131736526946108,
                    0.9176646706586826,
                    0.9236526946107785,
                    0.9281437125748503,
                    0.9311377245508982,
                    0.9341317365269461,
                    0.9341317365269461,
                    0.9356287425149701,
                    0.937125748502994,
                    0.937125748502994,
                    0.938622754491018,
                    0.9401197604790419,
                    0.9401197604790419,
                    0.9416167664670658,
                    0.9416167664670658,
                    0.9416167664670658,
                    0.9431137724550899,
                    0.9431137724550899,
                    0.9446107784431138,
                    0.9446107784431138,
                    0.9461077844311377,
                    0.9461077844311377,
                    0.9461077844311377,
                    0.9476047904191617,
                    0.9476047904191617,
                    0.9476047904191617
                  ],
                  "yaxis": "y5"
                },
                {
                  "customdata": [
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=train<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "train",
                  "line": {
                    "color": "#636efa",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "train",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x3",
                  "y": [
                    0.907185628742515,
                    0.9086826347305389,
                    0.9161676646706587,
                    0.9221556886227545,
                    0.9236526946107785,
                    0.9251497005988024,
                    0.9296407185628742,
                    0.9311377245508982,
                    0.9326347305389222,
                    0.9341317365269461,
                    0.9356287425149701,
                    0.937125748502994,
                    0.9401197604790419,
                    0.938622754491018,
                    0.9416167664670658,
                    0.9416167664670658,
                    0.9431137724550899,
                    0.9446107784431138,
                    0.9446107784431138,
                    0.9461077844311377,
                    0.9476047904191617,
                    0.9476047904191617,
                    0.9476047904191617,
                    0.9476047904191617,
                    0.9476047904191617,
                    0.9491017964071856,
                    0.9491017964071856,
                    0.9491017964071856,
                    0.9505988023952096,
                    0.9505988023952096
                  ],
                  "yaxis": "y3"
                },
                {
                  "customdata": [
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ],
                    [
                      10,
                      10,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=1449308123<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "test",
                  "line": {
                    "color": "#EF553B",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "test",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x7",
                  "y": [
                    0.8835820895522388,
                    0.8880597014925373,
                    0.8925373134328358,
                    0.8970149253731343,
                    0.9,
                    0.9044776119402985,
                    0.9074626865671642,
                    0.908955223880597,
                    0.9104477611940298,
                    0.9164179104477612,
                    0.917910447761194,
                    0.9208955223880597,
                    0.9238805970149254,
                    0.926865671641791,
                    0.926865671641791,
                    0.9298507462686567,
                    0.9298507462686567,
                    0.9313432835820895,
                    0.9313432835820895,
                    0.9298507462686567,
                    0.9313432835820895,
                    0.9328358208955224,
                    0.9328358208955224,
                    0.9328358208955224,
                    0.9328358208955224,
                    0.9343283582089552,
                    0.9343283582089552,
                    0.9343283582089552,
                    0.9343283582089552,
                    0.9328358208955224
                  ],
                  "yaxis": "y7"
                },
                {
                  "customdata": [
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ],
                    [
                      100,
                      100,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=676326879<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "test",
                  "line": {
                    "color": "#EF553B",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "test",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x5",
                  "y": [
                    0.891044776119403,
                    0.8985074626865671,
                    0.9,
                    0.9029850746268657,
                    0.9029850746268657,
                    0.9044776119402985,
                    0.9104477611940298,
                    0.9134328358208955,
                    0.9164179104477612,
                    0.9194029850746268,
                    0.917910447761194,
                    0.9208955223880597,
                    0.9238805970149254,
                    0.9253731343283582,
                    0.9283582089552239,
                    0.9283582089552239,
                    0.9283582089552239,
                    0.9283582089552239,
                    0.9283582089552239,
                    0.9313432835820895,
                    0.9328358208955224,
                    0.9328358208955224,
                    0.9328358208955224,
                    0.9328358208955224,
                    0.9343283582089552,
                    0.9328358208955224,
                    0.9343283582089552,
                    0.9343283582089552,
                    0.9328358208955224,
                    0.9343283582089552
                  ],
                  "yaxis": "y5"
                },
                {
                  "customdata": [
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ],
                    [
                      1000,
                      1000,
                      0.2
                    ]
                  ],
                  "hovertemplate": "type=test<br>learning_rate=%{customdata[1]}<br>batch_size=%{customdata[0]}<br>run_id=881033356<br>epoch=%{x}<br>accuracy=%{y}<br>dropout_fraction=%{customdata[2]}<extra></extra>",
                  "legendgroup": "test",
                  "line": {
                    "color": "#EF553B",
                    "dash": "solid"
                  },
                  "marker": {
                    "symbol": "circle"
                  },
                  "mode": "lines",
                  "name": "test",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "scatter",
                  "x": [
                    1,
                    2,
                    3,
                    4,
                    5,
                    6,
                    7,
                    8,
                    9,
                    10,
                    11,
                    12,
                    13,
                    14,
                    15,
                    16,
                    17,
                    18,
                    19,
                    20,
                    21,
                    22,
                    23,
                    24,
                    25,
                    26,
                    27,
                    28,
                    29,
                    30
                  ],
                  "xaxis": "x3",
                  "y": [
                    0.8985074626865671,
                    0.9029850746268657,
                    0.9029850746268657,
                    0.9074626865671642,
                    0.9134328358208955,
                    0.917910447761194,
                    0.9223880597014925,
                    0.926865671641791,
                    0.9253731343283582,
                    0.926865671641791,
                    0.9283582089552239,
                    0.9283582089552239,
                    0.9298507462686567,
                    0.9298507462686567,
                    0.9298507462686567,
                    0.9298507462686567,
                    0.9298507462686567,
                    0.9313432835820895,
                    0.9313432835820895,
                    0.9328358208955224,
                    0.9328358208955224,
                    0.9343283582089552,
                    0.9343283582089552,
                    0.9343283582089552,
                    0.9343283582089552,
                    0.935820895522388,
                    0.935820895522388,
                    0.935820895522388,
                    0.935820895522388,
                    0.935820895522388
                  ],
                  "yaxis": "y3"
                }
              ],
              "layout": {
                "annotations": [
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "batch_size=10",
                    "x": 0.15666666666666665,
                    "xanchor": "center",
                    "xref": "paper",
                    "y": 0.9999999999999998,
                    "yanchor": "bottom",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "batch_size=100",
                    "x": 0.49,
                    "xanchor": "center",
                    "xref": "paper",
                    "y": 0.9999999999999998,
                    "yanchor": "bottom",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "batch_size=1000",
                    "x": 0.8233333333333333,
                    "xanchor": "center",
                    "xref": "paper",
                    "y": 0.9999999999999998,
                    "yanchor": "bottom",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "learning_rate=1000",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.15666666666666665,
                    "yanchor": "middle",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "learning_rate=100",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.4999999999999999,
                    "yanchor": "middle",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "learning_rate=10",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.8433333333333332,
                    "yanchor": "middle",
                    "yref": "paper"
                  }
                ],
                "legend": {
                  "title": {
                    "text": "type"
                  },
                  "tracegroupgap": 0
                },
                "margin": {
                  "t": 60
                },
                "template": {
                  "data": {
                    "bar": [
                      {
                        "error_x": {
                          "color": "#2a3f5f"
                        },
                        "error_y": {
                          "color": "#2a3f5f"
                        },
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "bar"
                      }
                    ],
                    "barpolar": [
                      {
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "barpolar"
                      }
                    ],
                    "carpet": [
                      {
                        "aaxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "baxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "type": "carpet"
                      }
                    ],
                    "choropleth": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "choropleth"
                      }
                    ],
                    "contour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "contour"
                      }
                    ],
                    "contourcarpet": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "contourcarpet"
                      }
                    ],
                    "heatmap": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmap"
                      }
                    ],
                    "heatmapgl": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmapgl"
                      }
                    ],
                    "histogram": [
                      {
                        "marker": {
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "histogram"
                      }
                    ],
                    "histogram2d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2d"
                      }
                    ],
                    "histogram2dcontour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2dcontour"
                      }
                    ],
                    "mesh3d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "mesh3d"
                      }
                    ],
                    "parcoords": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "parcoords"
                      }
                    ],
                    "pie": [
                      {
                        "automargin": true,
                        "type": "pie"
                      }
                    ],
                    "scatter": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter"
                      }
                    ],
                    "scatter3d": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter3d"
                      }
                    ],
                    "scattercarpet": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattercarpet"
                      }
                    ],
                    "scattergeo": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergeo"
                      }
                    ],
                    "scattergl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergl"
                      }
                    ],
                    "scattermapbox": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattermapbox"
                      }
                    ],
                    "scatterpolar": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolar"
                      }
                    ],
                    "scatterpolargl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolargl"
                      }
                    ],
                    "scatterternary": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterternary"
                      }
                    ],
                    "surface": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "surface"
                      }
                    ],
                    "table": [
                      {
                        "cells": {
                          "fill": {
                            "color": "#EBF0F8"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "header": {
                          "fill": {
                            "color": "#C8D4E3"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "type": "table"
                      }
                    ]
                  },
                  "layout": {
                    "annotationdefaults": {
                      "arrowcolor": "#2a3f5f",
                      "arrowhead": 0,
                      "arrowwidth": 1
                    },
                    "autotypenumbers": "strict",
                    "coloraxis": {
                      "colorbar": {
                        "outlinewidth": 0,
                        "ticks": ""
                      }
                    },
                    "colorscale": {
                      "diverging": [
                        [
                          0,
                          "#8e0152"
                        ],
                        [
                          0.1,
                          "#c51b7d"
                        ],
                        [
                          0.2,
                          "#de77ae"
                        ],
                        [
                          0.3,
                          "#f1b6da"
                        ],
                        [
                          0.4,
                          "#fde0ef"
                        ],
                        [
                          0.5,
                          "#f7f7f7"
                        ],
                        [
                          0.6,
                          "#e6f5d0"
                        ],
                        [
                          0.7,
                          "#b8e186"
                        ],
                        [
                          0.8,
                          "#7fbc41"
                        ],
                        [
                          0.9,
                          "#4d9221"
                        ],
                        [
                          1,
                          "#276419"
                        ]
                      ],
                      "sequential": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ],
                      "sequentialminus": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ]
                    },
                    "colorway": [
                      "#636efa",
                      "#EF553B",
                      "#00cc96",
                      "#ab63fa",
                      "#FFA15A",
                      "#19d3f3",
                      "#FF6692",
                      "#B6E880",
                      "#FF97FF",
                      "#FECB52"
                    ],
                    "font": {
                      "color": "#2a3f5f"
                    },
                    "geo": {
                      "bgcolor": "white",
                      "lakecolor": "white",
                      "landcolor": "#E5ECF6",
                      "showlakes": true,
                      "showland": true,
                      "subunitcolor": "white"
                    },
                    "hoverlabel": {
                      "align": "left"
                    },
                    "hovermode": "closest",
                    "mapbox": {
                      "style": "light"
                    },
                    "paper_bgcolor": "white",
                    "plot_bgcolor": "#E5ECF6",
                    "polar": {
                      "angularaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "radialaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "scene": {
                      "xaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "yaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "zaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      }
                    },
                    "shapedefaults": {
                      "line": {
                        "color": "#2a3f5f"
                      }
                    },
                    "ternary": {
                      "aaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "baxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "caxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "title": {
                      "x": 0.05
                    },
                    "xaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    },
                    "yaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    }
                  }
                },
                "width": 500,
                "xaxis": {
                  "anchor": "y",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "title": {
                    "text": "epoch"
                  }
                },
                "xaxis2": {
                  "anchor": "y2",
                  "domain": [
                    0.3333333333333333,
                    0.6466666666666666
                  ],
                  "matches": "x",
                  "title": {
                    "text": "epoch"
                  }
                },
                "xaxis3": {
                  "anchor": "y3",
                  "domain": [
                    0.6666666666666666,
                    0.98
                  ],
                  "matches": "x",
                  "title": {
                    "text": "epoch"
                  }
                },
                "xaxis4": {
                  "anchor": "y4",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis5": {
                  "anchor": "y5",
                  "domain": [
                    0.3333333333333333,
                    0.6466666666666666
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis6": {
                  "anchor": "y6",
                  "domain": [
                    0.6666666666666666,
                    0.98
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis7": {
                  "anchor": "y7",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis8": {
                  "anchor": "y8",
                  "domain": [
                    0.3333333333333333,
                    0.6466666666666666
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "xaxis9": {
                  "anchor": "y9",
                  "domain": [
                    0.6666666666666666,
                    0.98
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "yaxis": {
                  "anchor": "x",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "title": {
                    "text": "accuracy"
                  }
                },
                "yaxis2": {
                  "anchor": "x2",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis3": {
                  "anchor": "x3",
                  "domain": [
                    0,
                    0.3133333333333333
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis4": {
                  "anchor": "x4",
                  "domain": [
                    0.34333333333333327,
                    0.6566666666666665
                  ],
                  "matches": "y",
                  "title": {
                    "text": "accuracy"
                  }
                },
                "yaxis5": {
                  "anchor": "x5",
                  "domain": [
                    0.34333333333333327,
                    0.6566666666666665
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis6": {
                  "anchor": "x6",
                  "domain": [
                    0.34333333333333327,
                    0.6566666666666665
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis7": {
                  "anchor": "x7",
                  "domain": [
                    0.6866666666666665,
                    0.9999999999999998
                  ],
                  "matches": "y",
                  "title": {
                    "text": "accuracy"
                  }
                },
                "yaxis8": {
                  "anchor": "x8",
                  "domain": [
                    0.6866666666666665,
                    0.9999999999999998
                  ],
                  "matches": "y",
                  "showticklabels": false
                },
                "yaxis9": {
                  "anchor": "x9",
                  "domain": [
                    0.6866666666666665,
                    0.9999999999999998
                  ],
                  "matches": "y",
                  "showticklabels": false
                }
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        }
      ],
      "source": [
        "runs_df = pd.concat(results)\n",
        "\n",
        "# plot train and test loss per epoch, faceted by learning rate (rows) and batch size (columns)\n",
        "px.line(\n",
        "    runs_df,\n",
        "    line_group=\"run_id\",\n",
        "    x=\"epoch\",\n",
        "    y=\"loss\",\n",
        "    color=\"type\",\n",
        "    hover_data=[\"batch_size\", \"learning_rate\", \"dropout_fraction\"],\n",
        "    facet_row=\"learning_rate\",\n",
        "    facet_col=\"batch_size\",\n",
        "    width=500,\n",
        ").show()\n",
        "\n",
        "# plot train and test accuracy per epoch, using the same faceting\n",
        "px.line(\n",
        "    runs_df,\n",
        "    line_group=\"run_id\",\n",
        "    x=\"epoch\",\n",
        "    y=\"accuracy\",\n",
        "    color=\"type\",\n",
        "    hover_data=[\"batch_size\", \"learning_rate\", \"dropout_fraction\"],\n",
        "    facet_row=\"learning_rate\",\n",
        "    facet_col=\"batch_size\",\n",
        "    width=500,\n",
        ").show()\n"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {
        "id": "MiBQDcMPpgkS"
      },
      "source": [
        "## 8. Plot the before-and-after results using the best matrix found during training\n",
        "\n",
        "The better the matrix, the more cleanly it separates the similar pairs from the dissimilar pairs."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "id": "hzjoyLDOpgkS"
      },
      "outputs": [],
      "source": [
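        "# apply the matrix from the run with the highest test accuracy to the original data\n",
        "# (restricting to test rows avoids picking a matrix that merely overfits the training set)\n",
        "test_runs_df = runs_df[runs_df[\"type\"] == \"test\"]\n",
        "best_run = test_runs_df.sort_values(by=\"accuracy\", ascending=False).iloc[0]\n",
        "best_matrix = best_run[\"matrix\"]\n",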
        "apply_matrix_to_embeddings_dataframe(best_matrix, df)\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "id": "nLnvABnXpgkS",
        "outputId": "0c070faa-6e3e-4765-b082-565c72a609be"
      },
      "outputs": [
        {
          "data": {
            "application/vnd.plotly.v1+json": {
              "config": {
                "plotlyServerURL": "https://plot.ly"
              },
              "data": [
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "1",
                  "marker": {
                    "color": "#636efa",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "1",
                  "offsetgroup": "1",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "histogram",
                  "x": [
                    0.9267355919345323,
                    0.8959824209230313,
                    0.9119725922265434,
                    0.854066984766886,
                    0.892887342475088,
                    0.9197115510183956,
                    0.8645429616364463,
                    0.8314164166177409,
                    0.7243313891695883,
                    0.8819971504547897,
                    0.7956215051893748,
                    0.7959481856578454,
                    0.8682525487547492,
                    0.8973704561532523,
                    0.8648042605746787,
                    0.9236698981903941,
                    0.9834804742323746,
                    0.8152447410515193,
                    0.8251720025778885,
                    0.8138195587168373,
                    0.8041885875094662,
                    0.9329690881953521,
                    0.9560346908585096,
                    0.9727875559800526,
                    0.8739787488907723,
                    0.8208200937845748,
                    0.7246913159387628,
                    0.9324916305400826,
                    0.8285737172791849,
                    0.8797008558439083,
                    0.820333215489803,
                    0.9370111008629193,
                    0.8983827475968527,
                    0.8312111255338568,
                    0.8164052516323846,
                    0.8908148647559723,
                    0.7466264012908705,
                    0.749651964301876,
                    0.87375582683117,
                    0.7849398161817545,
                    0.8309506403579568,
                    0.9307212180139316,
                    0.8281747408531835,
                    0.9529528469109777,
                    0.7828662080109449,
                    0.887100957550481,
                    0.9000278356226864,
                    0.8805448739521785,
                    0.930239131235984,
                    0.8801954897699975,
                    0.8529206900756547,
                    0.9467797362617967,
                    0.950367690540818,
                    0.7030531843080301,
                    0.8643992401506337,
                    0.8536886651933216,
                    0.9619331104636477,
                    0.9798279220663803,
                    0.8545739729953584,
                    0.8957115039259408,
                    0.8241137169318955,
                    0.8234984837204449,
                    0.8936706246811023,
                    0.8987178417246647,
                    0.9081806514389928,
                    0.9208852073112466,
                    0.8961858080980447,
                    0.8831329491347061,
                    0.9282623092166472,
                    0.8990849228018752,
                    0.8284548395672304,
                    0.8202091311790463,
                    0.8647762700375631,
                    0.840136958242196,
                    0.9887387562448309,
                    0.833342655983456,
                    0.828533112298018,
                    0.9117757392019592,
                    0.8706628948983062,
                    0.9279786042901599,
                    0.7389559895894061,
                    0.8433932035736679,
                    0.9240307526722153,
                    0.950769936374633,
                    0.8586024431762636,
                    0.8685107120156628,
                    0.875535036209879,
                    0.9894909159640456,
                    0.8279650063634755,
                    0.9108703739124493,
                    0.9090161898700972,
                    0.8603952576151226,
                    0.7791958170479588,
                    0.8800175078297319,
                    0.8442387843179446,
                    0.7672266106337577,
                    0.9379753906877134,
                    0.8637536227406404,
                    0.9190295684769377,
                    0.813748744480001,
                    0.9134886375270357,
                    0.8043760086358624,
                    0.8792300492020616,
                    0.8716796296133559,
                    0.8669146731224541,
                    0.7736224659067634,
                    0.9437235456545872,
                    0.905686329875817,
                    0.9534823418409002,
                    0.9150626356433706,
                    0.9409873570135594,
                    0.8111212499450311,
                    0.9171209901271343,
                    0.9126582215550586,
                    0.8337042981493767,
                    0.7317859043960847,
                    0.8444929460069593,
                    0.8561920422842992,
                    0.7765276739806221,
                    0.8526116779814181,
                    0.9178549171995068,
                    0.9238337665547537,
                    0.7218806795387105,
                    0.8180162420894919,
                    0.9687438998025378,
                    0.8354559162714565,
                    0.9146160667909478,
                    0.8082103463995212,
                    0.9563444955040953,
                    0.9066029892990775,
                    0.8485102485675593,
                    0.8154210965438722,
                    0.8860338155558548,
                    0.9280705420457523,
                    0.9835182434488767,
                    0.9653797793986199,
                    0.7815047672209323,
                    0.7156150747063292,
                    0.9256075945804699,
                    0.8135073898727201,
                    0.965501517353668,
                    0.8222681603043882,
                    0.907212186549752,
                    0.8611990749004483,
                    0.9075083701591861,
                    0.9452697087352507,
                    0.8792221638449876,
                    0.9261547230875113,
                    0.8628843081873457,
                    0.7825774684929468,
                    0.826587869384389,
                    0.7399697967202743,
                    0.7855475169419611,
                    0.9111614040328055,
                    0.8871057104665023,
                    0.8824403169032524,
                    0.8618250237448342,
                    0.9787037901590712,
                    0.8066230600465234,
                    0.82769109211187,
                    0.9246432066709112,
                    0.8840853142719706,
                    0.786484352300887,
                    0.9106863314452291,
                    0.9342800776988389,
                    0.8573335079738363,
                    0.7780871062391552,
                    0.7913314665963198,
                    0.8574377397709955,
                    0.9078366114351649,
                    0.7529739278117176,
                    0.8630106564789936,
                    0.9051526759332171,
                    0.7715467442267975,
                    0.894146587858075,
                    0.8095881901541373,
                    0.7733578316514722,
                    0.7600408374028372,
                    0.7819972017869313,
                    0.9003461714649055,
                    0.7428048011790054,
                    0.8645936498599726,
                    0.8158769881622202,
                    0.8338827592994027,
                    0.8272653967731632,
                    0.9017517376316297,
                    0.8480852025321463,
                    0.7970818331146111,
                    0.84837067058732,
                    0.9272909953469497,
                    0.9511439765702141,
                    0.8796630924442385,
                    0.8297595339070903,
                    0.8132311695727563,
                    0.846096509352579,
                    0.8787645385610889,
                    0.8591367322994465,
                    0.8452813265723063,
                    0.708120854745005,
                    0.8769677220320785,
                    0.957621649201,
                    0.7463356296247886,
                    0.8618039385828121,
                    0.9560112462834625,
                    0.8478374736767776,
                    0.769289015775232,
                    0.8456866935685449,
                    0.9014601942765245,
                    0.8816990623836058,
                    0.8836365012367311,
                    0.8078009762348856,
                    0.898471670333294,
                    0.9064470715463968,
                    0.8762712610709371,
                    0.9178852315161201,
                    0.7896235958446132,
                    0.8939345739482307,
                    0.9534018415944343,
                    0.8358882135213556,
                    0.9488657107840998,
                    0.9046799883600772,
                    0.7583576517359092,
                    0.9080459939022663,
                    0.7709722685822517,
                    0.9635512477648485,
                    0.9792712672362888,
                    0.8526700765974843,
                    0.8278133097990042,
                    0.9735858611728696,
                    0.721230194834802,
                    0.8257425311085005,
                    0.9243205490037274,
                    0.9183796453898277,
                    0.9029146937248601,
                    0.9410246048181872,
                    0.9609604036664618,
                    0.7467407974920351,
                    0.8831901217813377,
                    0.8173287200784538,
                    0.8067347032404598,
                    0.7921957436175584,
                    0.9110994793014074,
                    0.8678737504217436,
                    0.9117743220816726,
                    0.781256498099708,
                    0.8553931170417658,
                    0.8798565764815073,
                    0.8485358179238083,
                    0.7748765495278522,
                    0.9432062986428599,
                    0.8328320716129907,
                    0.7983629771463491,
                    0.9345589964322458,
                    0.780034700168418,
                    0.9894680324844076,
                    0.8239308905190431,
                    0.8236003498823307,
                    0.8346101074957768,
                    0.8273793488443836,
                    0.7872103673165679,
                    0.9502897884724039,
                    0.83306630344504,
                    0.9346568240975981,
                    0.8082083566882774,
                    0.8920672687911747,
                    0.8566523137465843,
                    0.7636170858064102,
                    0.8271498146927989,
                    0.8450776675259235,
                    0.9045266249905214,
                    0.8578964053464326,
                    0.8673866119197355,
                    0.8804224181917263,
                    0.8199459545727489,
                    0.932410033971128,
                    0.9096821262750205,
                    0.8658255622247675,
                    0.9386720385226357,
                    0.8517830114204678,
                    0.889433736353776,
                    0.978847593529176,
                    0.8369738178686744,
                    0.8438616304896431,
                    0.9457050132083015,
                    0.8699723446274389,
                    0.7795221418941651,
                    0.9136284834408402,
                    0.839461038218975,
                    0.9453279809479284,
                    0.7899532071809374,
                    0.9078373589637019,
                    0.8434980548175782,
                    0.8112068688921613,
                    0.9466506411385276,
                    0.931413666561485,
                    0.7932453730948084,
                    0.8205411414836395,
                    0.9243834389669592,
                    0.719616209472987,
                    0.7552985082978257,
                    0.9593440978701407,
                    0.917557937781671,
                    0.8643861907339583,
                    0.8315201130646898,
                    0.7608819746989568,
                    0.9704324558544556,
                    0.8037085307118571,
                    0.7785270629127378,
                    0.8044961889638995,
                    0.8313307525316546,
                    0.8064106357955508,
                    0.9291149169587132,
                    0.8412940941318643,
                    0.6917091086045903,
                    0.895204432897799,
                    0.8182250729145697,
                    0.8645847241904496,
                    0.8532020284403578,
                    0.8143634597102418,
                    0.8825747762544934,
                    0.7764652529619696,
                    0.850099367461302,
                    0.8616919096196407,
                    0.9257293995599788,
                    0.935772204341753,
                    0.7742657205171264,
                    0.7898710076766889,
                    0.8590438495382458,
                    0.9317809680608741,
                    0.9087109944408146,
                    0.949297999227784,
                    0.8813316524294974,
                    0.7372081408133433,
                    0.8838176418404169
                  ],
                  "xaxis": "x2",
                  "yaxis": "y2"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "1",
                  "marker": {
                    "color": "#636efa",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "1",
                  "offsetgroup": "1",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "histogram",
                  "x": [
                    0.9424796847531719,
                    0.9078956620721231,
                    0.8334324872928349,
                    0.9352180099421901,
                    0.9055462980371385,
                    0.8981939710469806,
                    0.8310153256803275,
                    0.8504676050313356,
                    0.8456281884302819,
                    0.8845204616348977,
                    0.9575409743129409,
                    0.8867362111499236,
                    0.8268148049156373,
                    0.9197424487445726,
                    0.7868932880014918,
                    0.7584994087629051,
                    0.9184151106611779,
                    0.8634069821364065,
                    0.8347803687511831,
                    0.8293627315810226,
                    0.9290633380383377,
                    0.8385821691321573,
                    0.9389267221382397,
                    0.890818442649049,
                    0.86634760562554,
                    0.8406483291061461,
                    0.8084243397427517,
                    0.8909500804242533,
                    0.9262896019507018,
                    0.8955541231110083,
                    0.8055268512037249,
                    0.758607678624549,
                    0.9609058491722305,
                    0.91495905751538,
                    0.8670137150812314,
                    0.8813831602682358,
                    0.8602255150075953,
                    0.9239960998195261,
                    0.91732217769794,
                    0.8037285395908439,
                    0.9196344979660896,
                    0.8179495018382927,
                    0.9015423003762686,
                    0.9054394610623084,
                    0.9309412941515826,
                    0.9421722892808463,
                    0.7632823176731237,
                    0.8622055676377434,
                    0.9855273108710355,
                    0.9144155566597573,
                    0.916057392559463,
                    0.8027504546689012,
                    0.7131090055200247,
                    0.86174194818737,
                    0.982873171115144,
                    0.8100227539053508,
                    0.8923878603823862,
                    0.8096643432711604,
                    0.8707613707684018,
                    0.8786740148818871,
                    0.8274639911687568,
                    0.8927098767487393,
                    0.9565597075186062,
                    0.9060728095743347,
                    0.738307517030644,
                    0.9645943657917937,
                    0.8755564011033787,
                    0.879644342330288,
                    0.8679709669180851,
                    0.9304235148160597,
                    0.8902804960376421,
                    0.8748369557157689,
                    0.9999999999999999,
                    0.7979398167195422,
                    0.8182553472910762,
                    0.7782108671759803,
                    0.8427610550014142,
                    0.8696408842646327,
                    0.8747903026537825,
                    0.914973368517656,
                    0.9651568968718233,
                    0.9775547988384379,
                    0.8964005885715726,
                    0.8689760349348122,
                    0.8501707274497268,
                    0.9069421089316828,
                    0.7682621587597694,
                    0.9658683147667629,
                    0.8946443487011166,
                    0.7855154267442119,
                    0.8963791544804274,
                    0.8062904921192978,
                    0.8165205978948239,
                    0.8392522254625653,
                    0.9456080863923884,
                    0.7904904126133966,
                    0.833126792097794,
                    0.7852156051245299,
                    0.7859162398886373,
                    0.9097674992318959,
                    0.8868692153350769,
                    0.9391826646753169,
                    0.9428151202937937,
                    0.7923603877885725,
                    0.9018727189911658,
                    0.9723161942344292,
                    0.7820369113228325,
                    0.9667234201162873,
                    0.978769627389609,
                    0.9155729781931277,
                    0.8273013970664075,
                    0.9603319621485501,
                    0.9298975009081959,
                    0.8775117467919693,
                    0.8614509568162568,
                    0.9144155657624043,
                    0.7783710792186642,
                    0.9701880190033496,
                    0.7858944693777298,
                    0.9278353490304795,
                    0.9472367444800102,
                    0.7834809788888359,
                    0.7997358978342995,
                    0.8459052935830644,
                    0.8612076995259889,
                    0.8470901722260982,
                    0.824037271004761,
                    0.865608650068826,
                    0.8023193246081538,
                    0.7836788857151791,
                    0.880404135119559,
                    0.8491559249509729,
                    0.788345270395367,
                    0.9461393747813323,
                    0.835123385080455,
                    0.8158174048388718,
                    0.8604581300972295,
                    0.9623616555004862,
                    0.8564688397784819,
                    0.8576867681204893,
                    0.8973905356978807,
                    0.8634447095761868,
                    0.8149528594606367,
                    0.873171253801101,
                    0.8653347676818378,
                    0.929525558735665,
                    0.8358267202818717,
                    0.971888682386554,
                    0.8500189240129448,
                    0.6201715858790215,
                    0.8982737437906866,
                    0.8919523978597029,
                    0.7327218620224804,
                    0.8329671225228632,
                    0.9265589851585685,
                    0.8976605724257473,
                    0.8865148832512937,
                    0.7893917264917192,
                    0.7303107673595587,
                    0.8428958487387793,
                    0.8712646524042454,
                    0.9726111208329614,
                    0.9368020234351375,
                    0.9270010843622818,
                    0.8900608737128451,
                    0.7975173141451626,
                    0.9403308743666237,
                    0.8484005148509558,
                    0.9285585494125882,
                    0.8461714640960527,
                    0.9301612552439464,
                    0.9840391344904358,
                    0.8305503032909712,
                    0.8985536902480058,
                    0.9476344055766088,
                    0.9342892661789283,
                    0.8849523247638441,
                    0.7736620851030366,
                    0.8083290901088768,
                    0.9510007696957726,
                    0.8677438102591386,
                    0.8324233959261911,
                    0.7379868650067231,
                    0.9049462205283083,
                    0.9044068964151009,
                    0.7810399099399672,
                    0.9040280419401343,
                    0.7720832557964016,
                    0.7168259249496903,
                    0.8657076231674912,
                    0.9689982290529879,
                    0.9330371348632155,
                    0.7014093149115691,
                    0.9056081832768474,
                    0.8483474394912361,
                    0.8729108893663646,
                    0.8494252835407102,
                    0.8702668029663239,
                    0.8703072652243842,
                    0.9279473628052111,
                    0.8615930026010632,
                    0.7590822846968434,
                    0.8435232136599701,
                    0.8264379743713927,
                    0.8793126202650794,
                    0.847452301256391,
                    0.7546334370590053,
                    0.8870818568791698,
                    0.8349553719854912,
                    0.9232007589383142,
                    0.7924421895458662,
                    0.8556103146657943,
                    0.8397958720336148,
                    0.9358165878534705,
                    0.904577353436614,
                    0.9022537114781464,
                    0.775603917181226,
                    0.946091618927627,
                    0.8264119461162666,
                    0.8261258105029201,
                    0.8605336598142989,
                    0.7518422489499549,
                    0.849587557879856,
                    0.9922578397415717,
                    0.7499254104864422,
                    0.8845204616348977,
                    0.8361936554693772,
                    0.9172228808488442,
                    0.8068135566092208,
                    0.795739929008372,
                    0.8632611464779958,
                    0.7612462576141025,
                    0.9589125421073597,
                    0.9555759038945358,
                    0.8822980105025566,
                    0.9663740139243834,
                    0.9071760951442449,
                    0.9335338894977118,
                    0.8042262160785201,
                    0.9399607295068667,
                    0.8318513711395181,
                    0.8697471272873054,
                    0.9103391819002411,
                    0.8272582065159091,
                    0.7868989539137853,
                    0.7416168920325092,
                    0.8828593510646834,
                    0.9141342991325323,
                    0.7259887492833588,
                    0.9478299721997572,
                    0.843766518385904,
                    0.919830425538435,
                    0.9069062941852448,
                    0.9036466185261808,
                    0.9817542893707696,
                    0.8833292621745382,
                    0.832556614533968,
                    0.8135910443631924,
                    0.9628932969024508,
                    0.9450804655496136,
                    0.9226384091529121,
                    0.8401818103049188,
                    0.723691406760321,
                    0.6828741135249211,
                    0.834410523069395,
                    0.9959256404542386,
                    0.952870396433041,
                    0.969514692554192,
                    0.9220387806044666,
                    0.9511950116111251,
                    0.87442203104197,
                    0.8399026046246612,
                    0.9029483760650204,
                    0.9097073428917352,
                    0.8651925582004045,
                    0.9178332691819033,
                    0.7556713752294848,
                    0.8601740894614596,
                    0.8250804250840363,
                    0.799473306929639,
                    0.8911389639861809,
                    0.915913776235107,
                    0.7867422041389165,
                    0.8035116695233039,
                    0.7702882636946234,
                    0.9060460430333088,
                    0.7214029229730072,
                    0.8607904806397634,
                    0.8228468643082103,
                    0.8900020169140401,
                    0.9343567736626528,
                    0.9305049279291139,
                    0.9664193138643489,
                    0.9008537853184969,
                    0.7625840742620333,
                    0.815302054727336,
                    0.9215061720798934,
                    0.7192673801865671,
                    0.8949994067748966,
                    0.9367566547265034,
                    0.7602684166275758,
                    0.8184439767612992,
                    0.8361983856596491,
                    0.7761725471827079,
                    0.7724780968772909,
                    0.9249211346782868,
                    0.8718843131924867,
                    0.8522890335712519,
                    0.9015475867709434,
                    0.8720699810318118,
                    0.8937599387455695,
                    0.8721713573852221,
                    0.8100783166142076,
                    1.0000000000000002,
                    0.8213222537973748,
                    0.8361185401136565,
                    0.8371907459006128,
                    0.9065697385076582,
                    0.752240671472798,
                    0.8283078905766531,
                    0.8499886819287953,
                    0.9097932369637356,
                    0.9529813104528191,
                    0.8449289750214674,
                    1,
                    0.8302949362084788,
                    0.7741532046500113,
                    0.8743828037305432,
                    0.8201855611163867,
                    0.8194689758101458,
                    0.7925076796225758,
                    0.8748126117575765,
                    0.8299510305557958,
                    0.9619426561868236,
                    0.8627070029199212
                  ],
                  "xaxis": "x",
                  "yaxis": "y"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=-1<br>dataset=train<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "-1",
                  "marker": {
                    "color": "#EF553B",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "-1",
                  "offsetgroup": "-1",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "histogram",
                  "x": [
                    0.7299945446332757,
                    0.761829365033925,
                    0.676235270353631,
                    0.7023016593603112,
                    0.7350156032869306,
                    0.7735236691667362,
                    0.7187241641280292,
                    0.8063818744486255,
                    0.7115273138749911,
                    0.7402039545462252,
                    0.7372226375565185,
                    0.719606336264161,
                    0.8344052997670786,
                    0.6881482918877712,
                    0.650513335277676,
                    0.7182967436895931,
                    0.7659970793010594,
                    0.6361476177437694,
                    0.7957002158983555,
                    0.7395402073375513,
                    0.7614384854511241,
                    0.6835300790002988,
                    0.6209194558826169,
                    0.7907726220786139,
                    0.7215750502680501,
                    0.7271947538454276,
                    0.6962333614625641,
                    0.7517476038206041,
                    0.7135529871009506,
                    0.7522919529124952,
                    0.7120639628754573,
                    0.7623014780353815,
                    0.7939574492876662,
                    0.6998873138766412,
                    0.7700594720774577,
                    0.7766161530342343,
                    0.7285080945118152,
                    0.7562511166739386,
                    0.8086622737220109,
                    0.7565297004511734,
                    0.7315242427462245,
                    0.791225232078265,
                    0.7281092467134649,
                    0.7675685917886341,
                    0.7122436997016454,
                    0.7600255476588682,
                    0.769465992253443,
                    0.7552507233110458,
                    0.7373719614604455,
                    0.7681449827665134,
                    0.7760194768221379,
                    0.8035677492769473,
                    0.7771752906080505,
                    0.6514739683312004,
                    0.744787832649546,
                    0.7700232252518491,
                    0.6901772464238046,
                    0.721128796446804,
                    0.686814078974691,
                    0.796786900996621,
                    0.8242176320872988,
                    0.7806742901384496,
                    0.7697696656361902,
                    0.7868668497422822,
                    0.7304548331410279,
                    0.7296767182508448,
                    0.699401040331297,
                    0.8332660282838579,
                    0.6513224421864793,
                    0.6927364371405096,
                    0.7491793002300279,
                    0.7909034218879171,
                    0.7754176150761527,
                    0.8004827006684735,
                    0.6659923526948868,
                    0.8129618884980553,
                    0.7496476113488948,
                    0.6519665061955423,
                    0.7319506042291191,
                    0.7367099004846358,
                    0.8590817275589286,
                    0.6684490623528697,
                    0.7469140136676841,
                    0.7269016830127544,
                    0.6834261236240542,
                    0.6390380393123325,
                    0.7129209978685723,
                    0.6879274953497815,
                    0.6720427402498903,
                    0.7343923261152341,
                    0.5977941468082364,
                    0.7574521074337901,
                    0.7302129263006347,
                    0.7380209108764002,
                    0.7280177437983031,
                    0.6870880300968067,
                    0.6928024686173956,
                    0.6403900073451696,
                    0.7209247906698092,
                    0.666799070142464,
                    0.7233569397428488,
                    0.7555267048881131,
                    0.7275931437975757,
                    0.7562785127332186,
                    0.736649021802634,
                    0.762569980781232,
                    0.7741923768467133,
                    0.6669773662198453,
                    0.7499779113599104,
                    0.7783371410262362,
                    0.7798574909618832,
                    0.7068277719647446,
                    0.7718066533969561,
                    0.7078874155127759,
                    0.7238814624322412,
                    0.8729026036201937,
                    0.7106759057210158,
                    0.7585767440008555,
                    0.6923565683937588,
                    0.6693898561996529,
                    0.7219944003835542,
                    0.7188064794656569,
                    0.7491451951131076,
                    0.6750197249394477,
                    0.7009868838756784,
                    0.6519589230722103,
                    0.775109860666869,
                    0.6682014813660948,
                    0.6618358724923147,
                    0.6218362070243146,
                    0.7518134154555007,
                    0.7571307427362168,
                    0.7546823360234822,
                    0.7416435744760188,
                    0.7676464608286054,
                    0.5799855321635428,
                    0.796191792361925,
                    0.6845373357552841,
                    0.7667984177770908,
                    0.7393609881148844,
                    0.7544580057165962,
                    0.7397495323890664,
                    0.7658298386950907,
                    0.6611125314078552,
                    0.7977728876184589,
                    0.831090991824215,
                    0.6982347705041605,
                    0.6312226759947304,
                    0.6482907496231076,
                    0.7658362595299695,
                    0.7518400125995974,
                    0.8025480920087656,
                    0.7461501494015976,
                    0.7239149641844659,
                    0.7090055516721123,
                    0.6846622430587006,
                    0.7112212118684095,
                    0.6794460508450644,
                    0.8344462829585355,
                    0.7638782230495382,
                    0.6558255308015893,
                    0.6799836520005059,
                    0.7148088830660722,
                    0.7658007439571484,
                    0.6581857665503666,
                    0.648734015773256,
                    0.8156725527309459,
                    0.6929202590284855,
                    0.7490919589112273,
                    0.7090723359022927,
                    0.7105572886194162,
                    0.7461374422467866,
                    0.7084742440808511,
                    0.6889378818898818,
                    0.7551565238112727,
                    0.7198789279593328,
                    0.7270482774940944,
                    0.6971190427893721,
                    0.7391610904601401,
                    0.6344734499604212,
                    0.6719507302378157,
                    0.6861159059531602,
                    0.6516118314317232,
                    0.7199095105818035,
                    0.6817881072485698,
                    0.7207373940156253,
                    0.7745467156569463,
                    0.6258059783246434,
                    0.7481056513643776,
                    0.7183327989261993,
                    0.7624705969064652,
                    0.703717229048042,
                    0.7487365873620485,
                    0.7495555007485976,
                    0.6409624588579408,
                    0.7176245775200687,
                    0.7537100520717849,
                    0.7569868075002099,
                    0.7392135510982174,
                    0.7188471960732984,
                    0.697413550280765,
                    0.7138614496887067,
                    0.7057641350396069,
                    0.7675079665142936,
                    0.7310427541091186,
                    0.7808018818735418,
                    0.7567255895826445,
                    0.7035962262766099,
                    0.7040750813384501,
                    0.8183159919038065,
                    0.7953911933398697,
                    0.7464891038038547,
                    0.6751591827598264,
                    0.7849943676377955,
                    0.7155963442284841,
                    0.7428993249606122,
                    0.7131100645054201,
                    0.7227595311429733,
                    0.6519548345531954,
                    0.7201522118536183,
                    0.86540799716664,
                    0.8128819241371503,
                    0.7278912446692862,
                    0.7305867175950502,
                    0.7171875192153516,
                    0.6755179003509543,
                    0.7256221359402913,
                    0.7003814947129137,
                    0.7486334697158199,
                    0.7232489166529666,
                    0.7347697330652992,
                    0.6493837702986368,
                    0.6454310256268904,
                    0.7085305859966166,
                    0.7709963397003181,
                    0.7628122486461532,
                    0.7260869667056613,
                    0.7656074896314918,
                    0.7309944223135776,
                    0.7575162043117213,
                    0.7425954755181217,
                    0.7978452334414571,
                    0.7414129597626153,
                    0.7369987033441427,
                    0.7249664966482501,
                    0.663939118162477,
                    0.7490232329485363,
                    0.7532303509685747,
                    0.6505502824396713,
                    0.6820403873171862,
                    0.7458589415356044,
                    0.65106761846338,
                    0.8190794585362886,
                    0.6404595320063431,
                    0.7620011212588133,
                    0.6793344580417779,
                    0.7470455239529016,
                    0.6254743025101126,
                    0.714021296346165,
                    0.7624376857959331,
                    0.6124088443110871,
                    0.7190909082546953,
                    0.7667977482791689,
                    0.7390282391537781,
                    0.731411776312689,
                    0.6910654011387792,
                    0.7669731150300952,
                    0.7473599833172766,
                    0.7757742306046117,
                    0.7524096654164605,
                    0.668078105815945,
                    0.6797298899209671,
                    0.734572374544983,
                    0.7851097143298676,
                    0.7342031538272321,
                    0.7372538892798539,
                    0.7335209046034747,
                    0.6838366965805461,
                    0.6892908218001774,
                    0.7368799079039347,
                    0.667817778038092,
                    0.7436021083467266,
                    0.643687287978847,
                    0.6459012534104801,
                    0.717961524900317,
                    0.651811280493696,
                    0.774663511332172,
                    0.6481450536649574,
                    0.7135017154296592,
                    0.718243043672562,
                    0.6840974559559992,
                    0.6237039667424505,
                    0.7511301324631847,
                    0.6731554748777204,
                    0.7311433676647258,
                    0.7407740108803432,
                    0.7219713524118947,
                    0.6945284597258313,
                    0.7403964086997485,
                    0.7281416758951161,
                    0.767616849591671,
                    0.7623013461443361,
                    0.794278871148869,
                    0.7094595086377886,
                    0.6363511838820268,
                    0.6955392554405656,
                    0.7448982185809608,
                    0.7328610107493276,
                    0.7208624134565648,
                    0.6762963031433926,
                    0.7755406771546928,
                    0.7045876515082797,
                    0.6244682388287953,
                    0.6742169930443296,
                    0.8182351707309709,
                    0.7329206562104693,
                    0.7750198586773644,
                    0.7686712995024062,
                    0.7382706738683358,
                    0.6670365140953741,
                    0.7122098843228497,
                    0.720363344417827,
                    0.7260325476617868,
                    0.7455849615803621,
                    0.7135971488970759,
                    0.7597698332957011,
                    0.7261113201174606,
                    0.7802411718313292,
                    0.6937507195359754,
                    0.7842882648198792,
                    0.6900501446041246,
                    0.760860218311663,
                    0.7134088358525988,
                    0.7053629634417926
                  ],
                  "xaxis": "x2",
                  "yaxis": "y2"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=-1<br>dataset=test<br>cosine_similarity=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "-1",
                  "marker": {
                    "color": "#EF553B",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "-1",
                  "offsetgroup": "-1",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "histogram",
                  "x": [
                    0.698216609902851,
                    0.7175312484264973,
                    0.7186004945024,
                    0.8000305631283213,
                    0.6982596730429885,
                    0.7632305498724974,
                    0.7138465594512551,
                    0.6788567152998355,
                    0.7373203069430523,
                    0.6873036529619025,
                    0.6503465541274354,
                    0.7365009792588864,
                    0.741093678604772,
                    0.735107851927508,
                    0.7789882788330387,
                    0.6997359150722298,
                    0.7996799028712657,
                    0.7467428244049107,
                    0.6687862526227949,
                    0.743064589979991,
                    0.8567601871608912,
                    0.7518934598338161,
                    0.7026922402796603,
                    0.7211365080382613,
                    0.706354045904417,
                    0.779427003097291,
                    0.7362962393365906,
                    0.7291751132805836,
                    0.7122378769301289,
                    0.7140477948952889,
                    0.7173405960056081,
                    0.7634875929311733,
                    0.7977581390590794,
                    0.7301463738402032,
                    0.7615332057968679,
                    0.682227985030492,
                    0.7334635989895589,
                    0.7386028577038499,
                    0.659084291835728,
                    0.7899820114980755,
                    0.7247172517704278,
                    0.7438155210487466,
                    0.7005346860296618,
                    0.6648553956533266,
                    0.7701566644966253,
                    0.7514961574415904,
                    0.7587991656983686,
                    0.7001882273133521,
                    0.6910707516646375,
                    0.7394355361240693,
                    0.7276824899179835,
                    0.6759744362016779,
                    0.8302185470787163,
                    0.6928641094502374,
                    0.7120538723839331,
                    0.7224960785372221,
                    0.7198045816069757,
                    0.6813558762031737,
                    0.7165801628045559,
                    0.7120832723046919,
                    0.8420619371167495,
                    0.946387336435492,
                    0.7554566849916029,
                    0.6539169025880401,
                    0.7809846972343978,
                    0.7724403150471626,
                    0.7005276086814838,
                    0.6393757830582598,
                    0.8206678516495857,
                    0.7220887623700465,
                    0.6457268309147112,
                    0.7355556783165726,
                    0.7154485704709247,
                    0.736485552747626,
                    0.6279868336962336,
                    0.6826307018159569,
                    0.6893086023402188,
                    0.6662701662224271,
                    0.7867533724923507,
                    0.7767747518359883,
                    0.8265509786609969,
                    0.7191298707058883,
                    0.7022356632617615,
                    0.7327905404619434,
                    0.744068997548466,
                    0.7610196098150734,
                    0.7115456073990181,
                    0.7432332956176703,
                    0.7893785433034154,
                    0.7290401089931655,
                    0.6577253300110991,
                    0.7003577024217508,
                    0.6656632326716906,
                    0.777774068734494,
                    0.7332058736487923,
                    0.6832255453323396,
                    0.7959026283113317,
                    0.7447573588360193,
                    0.6641622569109379,
                    0.6536973982676278,
                    0.6665423422110455,
                    0.6839542281076658,
                    0.7423443203765382,
                    0.7548386221527327,
                    0.6027110439271006,
                    0.6910860362919979,
                    0.6562817652314046,
                    0.716178983425107,
                    0.724403115068019,
                    0.7259909803954812,
                    0.7571530348849562,
                    0.733442357026098,
                    0.7422263713152575,
                    0.7194490356714829,
                    0.8324274302968815,
                    0.7326854438693811,
                    0.6965534375434462,
                    0.6112368704742529,
                    0.7304384905441154,
                    0.7720040972080942,
                    0.6935032899149339,
                    0.6744196671055273,
                    0.7147585872166342,
                    0.7752945283086001,
                    0.7247252468203336,
                    0.7487128692844189,
                    0.7073695766484023,
                    0.7770002807438351,
                    0.6765039364068596,
                    0.6337555851438214,
                    0.7894145395726685,
                    0.808127657824718,
                    0.6806212157836168,
                    0.6631556694778886,
                    0.7026864982467919,
                    0.6616453779258454,
                    0.7884566439105933,
                    0.7066759939380831,
                    0.6337579382878964,
                    0.6487574469000645,
                    0.772145520185077,
                    0.7232462166936153,
                    0.7901145921506159,
                    0.6972515066616096,
                    0.7122390390469722,
                    0.7192738985202007,
                    0.6592857746430886,
                    0.7609468966795828,
                    0.6603956483288737,
                    0.7457664593808938,
                    0.6826889303879785,
                    0.7934081087742163,
                    0.6763383894378445,
                    0.6839731303810913,
                    0.6687692061586001,
                    0.6991702352595418,
                    0.7797875942740597,
                    0.7663428993137384,
                    0.7674609664886143,
                    0.7252797254797089,
                    0.6976652869077468,
                    0.7158816756428659,
                    0.718241456712601,
                    0.7632800029357677,
                    0.7220393239650442,
                    0.7734594213360495,
                    0.6831354341128338,
                    0.8050740585092945,
                    0.8304016729411579,
                    0.6942297180429666,
                    0.7777210539532041,
                    0.7971492766742209,
                    0.7551377672686813,
                    0.756111346085797,
                    0.7377139073001521,
                    0.7292914327831885,
                    0.6619852966467218,
                    0.6633387313486273,
                    0.7753943622795527,
                    0.6931366760284077,
                    0.6879938635549627,
                    0.7934277451864097,
                    0.7095961561996172,
                    0.6721647904665006,
                    0.6639348167332335,
                    0.7190874751718003,
                    0.6487682662731885,
                    0.7712237228183192,
                    0.7195541308325332,
                    0.7624245070018526,
                    0.7066568592895756,
                    0.6955819267956074,
                    0.762644689565747,
                    0.696550099953965,
                    0.72160309057877,
                    0.6589853583691466,
                    0.7781076723886485,
                    0.7844353288099747,
                    0.6499942545071061,
                    0.7586643115109818,
                    0.7851245713510661,
                    0.6825431110733673,
                    0.7920473550917971,
                    0.7505505200683399,
                    0.7112992195413225,
                    0.6872424297540068,
                    0.6629403755138552,
                    0.7754417757832819,
                    0.7445419843378304,
                    0.7064660255551196,
                    0.7102764764345906,
                    0.7166584523700176,
                    0.745706231193076,
                    0.7628022956052315,
                    0.6960882904963708,
                    0.6837468719098787,
                    0.7520487263229376,
                    0.7604129787986132,
                    0.7522054011542562,
                    0.6898973175025894,
                    0.712053903439596,
                    0.7339122116041967,
                    0.6789458999380491,
                    0.763449576758855,
                    0.7406926320689465,
                    0.6976029213628422,
                    0.7734670110437211,
                    0.7670042286239444,
                    0.7194794729796166,
                    0.8039337600845294,
                    0.822887238017454,
                    0.7546355748553022,
                    0.712175872919759,
                    0.6283408438534046,
                    0.7277722488820856,
                    0.7848289730316366,
                    0.6470175548703326,
                    0.7175970646556328,
                    0.6982508276209234,
                    0.6931585425755658,
                    0.7105706262503181,
                    0.6541269887120206,
                    0.8404400660965669,
                    0.7278163563567496,
                    0.8022018950050529,
                    0.7449136144274442,
                    0.7254549036563473,
                    0.7708530061010334,
                    0.7226568414320923,
                    0.7174427406313559,
                    0.7242867487973796,
                    0.7465548093208114,
                    0.6481248100208876,
                    0.715420475795481,
                    0.7818995378198694,
                    0.710140386310007,
                    0.779070581516581,
                    0.7022765286826387,
                    0.7920014977874291,
                    0.7028249663680173,
                    0.7464011128583546,
                    0.6874796697586977,
                    0.7834259382246206,
                    0.7487992683674399,
                    0.6256280566008573,
                    0.7248416704075088,
                    0.6787298488859722,
                    0.7604099689852636,
                    0.7563309422531382,
                    0.6536692254847523,
                    0.7277402265583088,
                    0.7961681595257355,
                    0.7183023940359037,
                    0.8578324194628058,
                    0.6890352710148969,
                    0.711616174722821,
                    0.6560825239577808,
                    0.7235723022023411,
                    0.6649236563739442,
                    0.6800589793852263,
                    0.7785177771732936,
                    0.8277144895457349,
                    0.7047472917613488,
                    0.6981993581133783,
                    0.69214194925211,
                    0.7175225477364914,
                    0.6821700037384272,
                    0.6934882153010823,
                    0.6671459724713757,
                    0.7577056235720239,
                    0.7452347481085622,
                    0.7540647847248615,
                    0.7623045528688165,
                    0.8419961516314336,
                    0.7607631404114222,
                    0.7104592749279437,
                    0.7907219223857475,
                    0.6626370861321504,
                    0.7293323972767943,
                    0.6747790790615374,
                    0.7205393564241827,
                    0.7182141925188444,
                    0.6698620455142402,
                    0.7774615433926413,
                    0.68441903533162,
                    0.7195176784475413,
                    0.7765542578440102,
                    0.7653003235671018,
                    0.6588957811563716,
                    0.7049538466814345,
                    0.6767019827253833,
                    0.6852115350974048,
                    0.7159808946533304,
                    0.6275008181698795,
                    0.6641598464064111,
                    0.7653064009797307,
                    0.7846062245731015,
                    0.7131195190890488,
                    0.7388407888274415,
                    0.7078575690975646,
                    0.7922969773693673,
                    0.6399205949452071,
                    0.7522331808600956,
                    0.756127025845852,
                    0.7527950868321375,
                    0.7791392496558245,
                    0.7388745760013306,
                    0.6739605493576779,
                    0.6432673241279167,
                    0.7124181751534668,
                    0.669456709871883,
                    0.7067049471522552,
                    0.6685698115209102,
                    0.7430777123010961,
                    0.7510627360284545
                  ],
                  "xaxis": "x",
                  "yaxis": "y"
                }
              ],
              "layout": {
                "annotations": [
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "dataset=test",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.2425,
                    "yanchor": "middle",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "dataset=train",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.7575000000000001,
                    "yanchor": "middle",
                    "yref": "paper"
                  }
                ],
                "barmode": "overlay",
                "legend": {
                  "title": {
                    "text": "label"
                  },
                  "tracegroupgap": 0
                },
                "margin": {
                  "t": 60
                },
                "template": {
                  "data": {
                    "bar": [
                      {
                        "error_x": {
                          "color": "#2a3f5f"
                        },
                        "error_y": {
                          "color": "#2a3f5f"
                        },
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "bar"
                      }
                    ],
                    "barpolar": [
                      {
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "barpolar"
                      }
                    ],
                    "carpet": [
                      {
                        "aaxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "baxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "type": "carpet"
                      }
                    ],
                    "choropleth": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "choropleth"
                      }
                    ],
                    "contour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "contour"
                      }
                    ],
                    "contourcarpet": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "contourcarpet"
                      }
                    ],
                    "heatmap": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmap"
                      }
                    ],
                    "heatmapgl": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmapgl"
                      }
                    ],
                    "histogram": [
                      {
                        "marker": {
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "histogram"
                      }
                    ],
                    "histogram2d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2d"
                      }
                    ],
                    "histogram2dcontour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2dcontour"
                      }
                    ],
                    "mesh3d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "mesh3d"
                      }
                    ],
                    "parcoords": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "parcoords"
                      }
                    ],
                    "pie": [
                      {
                        "automargin": true,
                        "type": "pie"
                      }
                    ],
                    "scatter": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter"
                      }
                    ],
                    "scatter3d": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter3d"
                      }
                    ],
                    "scattercarpet": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattercarpet"
                      }
                    ],
                    "scattergeo": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergeo"
                      }
                    ],
                    "scattergl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergl"
                      }
                    ],
                    "scattermapbox": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattermapbox"
                      }
                    ],
                    "scatterpolar": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolar"
                      }
                    ],
                    "scatterpolargl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolargl"
                      }
                    ],
                    "scatterternary": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterternary"
                      }
                    ],
                    "surface": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "surface"
                      }
                    ],
                    "table": [
                      {
                        "cells": {
                          "fill": {
                            "color": "#EBF0F8"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "header": {
                          "fill": {
                            "color": "#C8D4E3"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "type": "table"
                      }
                    ]
                  },
                  "layout": {
                    "annotationdefaults": {
                      "arrowcolor": "#2a3f5f",
                      "arrowhead": 0,
                      "arrowwidth": 1
                    },
                    "autotypenumbers": "strict",
                    "coloraxis": {
                      "colorbar": {
                        "outlinewidth": 0,
                        "ticks": ""
                      }
                    },
                    "colorscale": {
                      "diverging": [
                        [
                          0,
                          "#8e0152"
                        ],
                        [
                          0.1,
                          "#c51b7d"
                        ],
                        [
                          0.2,
                          "#de77ae"
                        ],
                        [
                          0.3,
                          "#f1b6da"
                        ],
                        [
                          0.4,
                          "#fde0ef"
                        ],
                        [
                          0.5,
                          "#f7f7f7"
                        ],
                        [
                          0.6,
                          "#e6f5d0"
                        ],
                        [
                          0.7,
                          "#b8e186"
                        ],
                        [
                          0.8,
                          "#7fbc41"
                        ],
                        [
                          0.9,
                          "#4d9221"
                        ],
                        [
                          1,
                          "#276419"
                        ]
                      ],
                      "sequential": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ],
                      "sequentialminus": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ]
                    },
                    "colorway": [
                      "#636efa",
                      "#EF553B",
                      "#00cc96",
                      "#ab63fa",
                      "#FFA15A",
                      "#19d3f3",
                      "#FF6692",
                      "#B6E880",
                      "#FF97FF",
                      "#FECB52"
                    ],
                    "font": {
                      "color": "#2a3f5f"
                    },
                    "geo": {
                      "bgcolor": "white",
                      "lakecolor": "white",
                      "landcolor": "#E5ECF6",
                      "showlakes": true,
                      "showland": true,
                      "subunitcolor": "white"
                    },
                    "hoverlabel": {
                      "align": "left"
                    },
                    "hovermode": "closest",
                    "mapbox": {
                      "style": "light"
                    },
                    "paper_bgcolor": "white",
                    "plot_bgcolor": "#E5ECF6",
                    "polar": {
                      "angularaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "radialaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "scene": {
                      "xaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "yaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "zaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      }
                    },
                    "shapedefaults": {
                      "line": {
                        "color": "#2a3f5f"
                      }
                    },
                    "ternary": {
                      "aaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "baxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "caxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "title": {
                      "x": 0.05
                    },
                    "xaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    },
                    "yaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    }
                  }
                },
                "width": 500,
                "xaxis": {
                  "anchor": "y",
                  "domain": [
                    0,
                    0.98
                  ],
                  "title": {
                    "text": "cosine_similarity"
                  }
                },
                "xaxis2": {
                  "anchor": "y2",
                  "domain": [
                    0,
                    0.98
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "yaxis": {
                  "anchor": "x",
                  "domain": [
                    0,
                    0.485
                  ],
                  "title": {
                    "text": "count"
                  }
                },
                "yaxis2": {
                  "anchor": "x2",
                  "domain": [
                    0.515,
                    1
                  ],
                  "matches": "y",
                  "title": {
                    "text": "count"
                  }
                }
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Test accuracy: 88.8% ± 2.4%\n"
          ]
        },
        {
          "data": {
            "application/vnd.plotly.v1+json": {
              "config": {
                "plotlyServerURL": "https://plot.ly"
              },
              "data": [
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=1<br>dataset=train<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "1",
                  "marker": {
                    "color": "#636efa",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "1",
                  "offsetgroup": "1",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "histogram",
                  "x": [
                    0.7582729,
                    0.77073956,
                    0.65517455,
                    0.5354905,
                    0.7025479,
                    0.709722,
                    0.5727541,
                    0.412897,
                    0.30631626,
                    0.7510688,
                    0.4988944,
                    0.5430158,
                    0.6839247,
                    0.6838065,
                    0.73403287,
                    0.7992136,
                    0.94956607,
                    0.6076832,
                    0.6175979,
                    0.4301836,
                    0.41434872,
                    0.73805547,
                    0.8479354,
                    0.92995274,
                    0.6952341,
                    0.48424515,
                    0.37089303,
                    0.77779347,
                    0.46015453,
                    0.64896333,
                    0.5164182,
                    0.7807624,
                    0.744742,
                    0.60417265,
                    0.59455043,
                    0.69891477,
                    0.37484196,
                    0.42779225,
                    0.6438146,
                    0.5732102,
                    0.6928373,
                    0.8111064,
                    0.63694704,
                    0.8399712,
                    0.53589267,
                    0.72180825,
                    0.70790404,
                    0.6756934,
                    0.7727933,
                    0.7064891,
                    0.64490205,
                    0.8398692,
                    0.87921053,
                    0.27489004,
                    0.6942742,
                    0.57597476,
                    0.8717559,
                    0.9125966,
                    0.5301582,
                    0.60662705,
                    0.57393736,
                    0.6068608,
                    0.7286386,
                    0.70121765,
                    0.6917445,
                    0.8071895,
                    0.67622876,
                    0.79283214,
                    0.7993787,
                    0.72133046,
                    0.50220287,
                    0.57344776,
                    0.55474436,
                    0.6362008,
                    0.95821553,
                    0.49900997,
                    0.57039934,
                    0.733116,
                    0.78354275,
                    0.78056777,
                    0.05022569,
                    0.5561661,
                    0.77611494,
                    0.8746339,
                    0.66644144,
                    0.7035711,
                    0.6177985,
                    0.9615531,
                    0.54877394,
                    0.7237992,
                    0.80131793,
                    0.6017557,
                    0.5219256,
                    0.75939846,
                    0.5138731,
                    0.47525474,
                    0.8088046,
                    0.55364305,
                    0.8155018,
                    0.64197356,
                    0.78343874,
                    0.48175102,
                    0.7740524,
                    0.5643503,
                    0.5785323,
                    0.4045787,
                    0.844991,
                    0.78484446,
                    0.83895594,
                    0.779593,
                    0.8160813,
                    0.501844,
                    0.72787094,
                    0.74717385,
                    0.56744146,
                    0.23557578,
                    0.508822,
                    0.5945875,
                    0.30390128,
                    0.6806088,
                    0.73289305,
                    0.7714907,
                    0.35027343,
                    0.62142855,
                    0.89361686,
                    0.6107199,
                    0.80126244,
                    0.36980844,
                    0.8348204,
                    0.69843924,
                    0.40595374,
                    0.60189885,
                    0.67977715,
                    0.8707584,
                    0.94712615,
                    0.84084773,
                    0.38264382,
                    0.2997883,
                    0.82825613,
                    0.62465394,
                    0.8830344,
                    0.476409,
                    0.8145352,
                    0.5575372,
                    0.5749161,
                    0.8105687,
                    0.60813963,
                    0.8235246,
                    0.59305763,
                    0.35663584,
                    0.62649107,
                    0.3989369,
                    0.5393962,
                    0.85537386,
                    0.59931004,
                    0.7605111,
                    0.632939,
                    0.9163666,
                    0.4586151,
                    0.63326573,
                    0.82228065,
                    0.62719053,
                    0.44674626,
                    0.6449413,
                    0.7844299,
                    0.7109407,
                    0.42483827,
                    0.40851548,
                    0.6666175,
                    0.72658247,
                    0.29441452,
                    0.7422972,
                    0.6994506,
                    0.18581374,
                    0.678699,
                    0.6461491,
                    0.47299898,
                    0.3536248,
                    0.4810702,
                    0.74285,
                    0.31706694,
                    0.6093468,
                    0.49133518,
                    0.5890032,
                    0.4567156,
                    0.76003814,
                    0.60391974,
                    0.38637605,
                    0.5953039,
                    0.7995997,
                    0.8849446,
                    0.5667035,
                    0.65296185,
                    0.607227,
                    0.5879688,
                    0.71059513,
                    0.65723944,
                    0.607201,
                    0.2025656,
                    0.7741568,
                    0.8627294,
                    0.20287104,
                    0.7483574,
                    0.8585248,
                    0.47140092,
                    0.31818846,
                    0.57475,
                    0.79051167,
                    0.71378,
                    0.61821806,
                    0.63847613,
                    0.70790005,
                    0.70595086,
                    0.6324633,
                    0.8411397,
                    0.6010942,
                    0.6198959,
                    0.8601436,
                    0.506412,
                    0.8787982,
                    0.7363141,
                    0.5014173,
                    0.81275815,
                    0.43802464,
                    0.8606852,
                    0.90823215,
                    0.6048983,
                    0.5862797,
                    0.89965594,
                    0.21609125,
                    0.58538985,
                    0.74388975,
                    0.6436246,
                    0.68614113,
                    0.8133605,
                    0.8900723,
                    0.33205867,
                    0.5844268,
                    0.57690537,
                    0.55082875,
                    0.18703146,
                    0.7354434,
                    0.7188371,
                    0.73535186,
                    0.49179193,
                    0.60141474,
                    0.739191,
                    0.60902834,
                    0.5209424,
                    0.86558294,
                    0.54551184,
                    0.35854483,
                    0.74914163,
                    0.5216272,
                    0.9627959,
                    0.61666507,
                    0.36979172,
                    0.5368949,
                    0.43976095,
                    0.49865422,
                    0.8608738,
                    0.5655695,
                    0.81297505,
                    0.6252572,
                    0.6706519,
                    0.63214654,
                    0.50118244,
                    0.62449926,
                    0.70553565,
                    0.71667975,
                    0.6725781,
                    0.5215899,
                    0.64888376,
                    0.58514893,
                    0.8322057,
                    0.7555814,
                    0.62717474,
                    0.85005945,
                    0.5694707,
                    0.7400664,
                    0.94494027,
                    0.51169676,
                    0.34803283,
                    0.8389519,
                    0.5855034,
                    0.44634232,
                    0.78732824,
                    0.5224927,
                    0.8763063,
                    0.55516154,
                    0.77261907,
                    0.6934301,
                    0.39077526,
                    0.8408532,
                    0.8121041,
                    0.62674487,
                    0.51155764,
                    0.82464856,
                    0.17810936,
                    0.45419854,
                    0.90070426,
                    0.7507616,
                    0.6304127,
                    0.5382616,
                    0.2548589,
                    0.90965825,
                    0.5757083,
                    0.3881809,
                    0.47091305,
                    0.6259747,
                    0.549566,
                    0.78464717,
                    0.67623615,
                    0.22947542,
                    0.6341044,
                    0.47200313,
                    0.48406413,
                    0.54176223,
                    0.55456895,
                    0.70727295,
                    0.24309844,
                    0.5912585,
                    0.6418238,
                    0.7596782,
                    0.8406271,
                    0.38865283,
                    0.5345519,
                    0.67705786,
                    0.8471737,
                    0.78848296,
                    0.8296676,
                    0.6814048,
                    0.1439574,
                    0.76609397
                  ],
                  "xaxis": "x2",
                  "yaxis": "y2"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=1<br>dataset=test<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "1",
                  "marker": {
                    "color": "#636efa",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "1",
                  "offsetgroup": "1",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "histogram",
                  "x": [
                    0.8055613,
                    0.6490477,
                    0.53790224,
                    0.86933863,
                    0.7265006,
                    0.7500095,
                    0.5136602,
                    0.6108247,
                    0.6824276,
                    0.6863151,
                    0.81834275,
                    0.70145804,
                    0.6070372,
                    0.76361525,
                    0.47412366,
                    0.29016504,
                    0.7311523,
                    0.6745492,
                    0.5820018,
                    0.45364782,
                    0.7563268,
                    0.6551417,
                    0.8229127,
                    0.71966314,
                    0.65637314,
                    0.6056125,
                    0.40986726,
                    0.6351686,
                    0.8003523,
                    0.6638293,
                    0.5493354,
                    0.38898873,
                    0.89240587,
                    0.7085296,
                    0.6742074,
                    0.73229283,
                    0.6573821,
                    0.8563852,
                    0.73161244,
                    0.5125263,
                    0.73639345,
                    0.5608091,
                    0.6822197,
                    0.6509677,
                    0.77809495,
                    0.7077787,
                    0.42401087,
                    0.55358726,
                    0.96326643,
                    0.73163295,
                    0.7535971,
                    0.36794028,
                    0.25175893,
                    0.5587677,
                    0.9381429,
                    0.65131617,
                    0.707767,
                    0.404175,
                    0.738107,
                    0.5416304,
                    0.57372445,
                    0.76368237,
                    0.8565733,
                    0.7438249,
                    0.12734012,
                    0.88206637,
                    0.6993351,
                    0.6296972,
                    0.61613667,
                    0.7972349,
                    0.69560605,
                    0.6654787,
                    0.99999994,
                    0.48609376,
                    0.47301903,
                    0.55419683,
                    0.6296634,
                    0.4366933,
                    0.5387273,
                    0.7703018,
                    0.9061766,
                    0.936411,
                    0.67313033,
                    0.6194193,
                    0.37134537,
                    0.69975835,
                    0.47905326,
                    0.91620946,
                    0.7396347,
                    0.4008342,
                    0.69292057,
                    0.45092002,
                    0.48777297,
                    0.52802837,
                    0.821909,
                    0.57007617,
                    0.55647916,
                    0.5132801,
                    0.43960932,
                    0.67652094,
                    0.7167868,
                    0.7868749,
                    0.8414399,
                    0.30066487,
                    0.749063,
                    0.9248404,
                    0.29511026,
                    0.8829613,
                    0.9339626,
                    0.80454856,
                    0.66526824,
                    0.88261116,
                    0.8127759,
                    0.63907516,
                    0.5948788,
                    0.779537,
                    0.2734612,
                    0.91090304,
                    0.6181431,
                    0.8291171,
                    0.8786744,
                    0.45176527,
                    0.4650637,
                    0.52472,
                    0.6086645,
                    0.61167514,
                    0.618851,
                    0.59550864,
                    0.5683689,
                    0.54923105,
                    0.72792804,
                    0.5108976,
                    0.17235357,
                    0.8798169,
                    0.45582134,
                    0.6490313,
                    0.5987218,
                    0.8585789,
                    0.57275724,
                    0.6698848,
                    0.708311,
                    0.5580988,
                    0.5629255,
                    0.7218422,
                    0.61439306,
                    0.7661999,
                    0.5670317,
                    0.87961054,
                    0.57681245,
                    -0.025263075,
                    0.7305135,
                    0.72332287,
                    0.23190439,
                    0.49458456,
                    0.7727562,
                    0.63174117,
                    0.63469034,
                    0.24959645,
                    0.22246554,
                    0.6352285,
                    0.7173436,
                    0.8997552,
                    0.71509236,
                    0.68947405,
                    0.66640544,
                    0.2979133,
                    0.7765941,
                    0.61110735,
                    0.7824376,
                    0.45629472,
                    0.7814283,
                    0.94727206,
                    0.57099396,
                    0.7175106,
                    0.8542127,
                    0.7793517,
                    0.63087165,
                    0.3020511,
                    0.4713252,
                    0.8319655,
                    0.65548015,
                    0.43472487,
                    0.44796726,
                    0.73886794,
                    0.65364546,
                    0.5894948,
                    0.67677295,
                    0.34056446,
                    0.2893894,
                    0.67523193,
                    0.9004053,
                    0.7664299,
                    0.070429526,
                    0.7358473,
                    0.596747,
                    0.64145684,
                    0.70690864,
                    0.6161911,
                    0.63442343,
                    0.8034927,
                    0.65551484,
                    0.15030503,
                    0.6618132,
                    0.28821555,
                    0.6077155,
                    0.67070943,
                    0.23067664,
                    0.6557098,
                    0.55279106,
                    0.7644635,
                    0.28628072,
                    0.6755361,
                    0.5833505,
                    0.8240328,
                    0.67054516,
                    0.61436915,
                    0.51077014,
                    0.8207636,
                    0.61361855,
                    0.46712258,
                    0.40974268,
                    0.5336698,
                    0.558656,
                    0.97240293,
                    0.32733288,
                    0.6863151,
                    0.4985424,
                    0.79023767,
                    0.54902637,
                    0.34437552,
                    0.5104175,
                    0.1298324,
                    0.8800378,
                    0.8841439,
                    0.61748284,
                    0.9165488,
                    0.6865411,
                    0.8628597,
                    0.5251879,
                    0.8125164,
                    0.673494,
                    0.6672625,
                    0.73910636,
                    0.49335945,
                    0.27169147,
                    0.27906692,
                    0.676355,
                    0.72769314,
                    0.009032255,
                    0.86168426,
                    0.554593,
                    0.74324536,
                    0.78169584,
                    0.7702316,
                    0.95651525,
                    0.66262144,
                    0.6954001,
                    0.5656505,
                    0.8851712,
                    0.8271407,
                    0.7820752,
                    0.5980351,
                    0.37550944,
                    -0.056717057,
                    0.6385865,
                    0.9871788,
                    0.8663082,
                    0.9173605,
                    0.7812852,
                    0.8634348,
                    0.6151981,
                    0.4764769,
                    0.7129371,
                    0.78437144,
                    0.55460554,
                    0.765633,
                    0.3909395,
                    0.5777054,
                    0.46684957,
                    0.42080644,
                    0.73171866,
                    0.7438681,
                    0.46907872,
                    0.5971506,
                    0.35307956,
                    0.70854855,
                    0.14134333,
                    0.63074076,
                    0.6841741,
                    0.69037616,
                    0.8382874,
                    0.8383749,
                    0.8832028,
                    0.7200073,
                    0.3631567,
                    0.3802071,
                    0.7435165,
                    0.31047463,
                    0.7060255,
                    0.78007823,
                    0.42565107,
                    0.58386827,
                    0.4447159,
                    0.5319243,
                    0.2536062,
                    0.7522945,
                    0.6332803,
                    0.7163948,
                    0.71887827,
                    0.7136859,
                    0.70331776,
                    0.7141161,
                    0.51288867,
                    1,
                    0.65905374,
                    0.558384,
                    0.543979,
                    0.6952025,
                    0.36829284,
                    0.66087174,
                    0.7381764,
                    0.8295663,
                    0.83220756,
                    0.49746192,
                    1,
                    0.5527734,
                    0.3635958,
                    0.552087,
                    0.4612422,
                    0.55174565,
                    0.37825137,
                    0.6775494,
                    0.28022328,
                    0.8762227,
                    0.6699667
                  ],
                  "xaxis": "x",
                  "yaxis": "y"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=-1<br>dataset=train<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "-1",
                  "marker": {
                    "color": "#EF553B",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "-1",
                  "offsetgroup": "-1",
                  "orientation": "v",
                  "showlegend": true,
                  "type": "histogram",
                  "x": [
                    0.14302981,
                    -0.11057069,
                    -0.17724767,
                    -0.24500036,
                    -0.062143553,
                    0.14316264,
                    0.15678164,
                    0.23773395,
                    -0.07846958,
                    -0.05204554,
                    -0.06655996,
                    -0.1291204,
                    0.443245,
                    -0.23869778,
                    -0.1645332,
                    -0.11218425,
                    0.23315865,
                    -0.38848343,
                    0.114264876,
                    -0.05154337,
                    0.18390255,
                    -0.27475196,
                    -0.05900202,
                    0.22430505,
                    -0.08546959,
                    -0.19892967,
                    -0.009420975,
                    -0.025093991,
                    -0.031790845,
                    -0.0051786974,
                    -0.076038264,
                    -0.024751754,
                    0.18652028,
                    0.011254758,
                    0.108140975,
                    0.09609464,
                    -0.07145849,
                    -0.06851679,
                    0.26567802,
                    -0.08995881,
                    -0.0607494,
                    0.10250999,
                    -0.010349991,
                    0.025486441,
                    -0.036893804,
                    0.020803962,
                    0.072955474,
                    0.12435959,
                    -0.052231353,
                    0.37205556,
                    0.17689727,
                    -0.0055993614,
                    0.19630852,
                    -0.26762488,
                    -0.0044234265,
                    0.021829708,
                    -0.09977905,
                    -0.103702135,
                    -0.08513733,
                    0.43749526,
                    0.36857948,
                    0.10853024,
                    0.13605508,
                    0.15389544,
                    -0.13927853,
                    -0.14963916,
                    -0.3644624,
                    0.52444035,
                    -0.40716577,
                    -0.084710844,
                    -0.07702993,
                    0.28454077,
                    0.17952761,
                    0.19037494,
                    -0.119846344,
                    0.3372389,
                    -0.25199848,
                    -0.14442691,
                    -0.21720155,
                    -0.09363778,
                    0.43611217,
                    -0.25409278,
                    0.059805665,
                    0.08246046,
                    -0.13380173,
                    -0.13812494,
                    -0.050808128,
                    -0.06082993,
                    -0.29077193,
                    -0.28214246,
                    -0.26683524,
                    0.15480486,
                    -0.19142698,
                    -0.07857721,
                    0.059667934,
                    -0.27533692,
                    0.0012983828,
                    -0.41322538,
                    -0.042272426,
                    -0.14912425,
                    -0.13140962,
                    0.33333275,
                    -0.07854002,
                    0.17386948,
                    0.12033723,
                    0.16654651,
                    -0.10265535,
                    -0.12654941,
                    -0.029241713,
                    0.14599761,
                    0.020325862,
                    -0.07572659,
                    -0.044500634,
                    -0.25713974,
                    0.12043542,
                    0.57479286,
                    0.00869727,
                    -0.054210614,
                    -0.1706504,
                    -0.32414228,
                    -0.2668199,
                    -0.26956692,
                    -0.05117539,
                    -0.06610332,
                    -0.10072504,
                    -0.17387511,
                    -0.08241673,
                    -0.15317225,
                    -0.1389603,
                    -0.23053482,
                    -0.03335168,
                    0.0077261436,
                    -0.108453214,
                    -0.16304685,
                    0.101268604,
                    -0.48224252,
                    0.35240975,
                    -0.12512998,
                    -0.078767695,
                    -0.02886478,
                    -0.11151735,
                    0.21486124,
                    0.06401499,
                    -0.07141201,
                    0.24201383,
                    0.38163814,
                    -0.18371437,
                    -0.33098587,
                    -0.17749745,
                    0.088400126,
                    0.0389295,
                    0.30001613,
                    0.14689615,
                    0.2980527,
                    -0.23508236,
                    -0.20070519,
                    -0.20603697,
                    0.09391333,
                    0.413654,
                    0.057752814,
                    -0.0879763,
                    -0.085120164,
                    -0.1986266,
                    0.112001404,
                    -0.24664646,
                    -0.2028898,
                    0.21875925,
                    -0.21640524,
                    0.040345095,
                    -0.22236083,
                    -0.028793441,
                    0.026597003,
                    -0.17050122,
                    -0.17038809,
                    -0.02601072,
                    -0.11057518,
                    0.04983874,
                    -0.09951373,
                    -0.05255925,
                    -0.21587878,
                    -0.042706307,
                    -0.13658924,
                    -0.27548197,
                    -0.080267474,
                    0.08046054,
                    -0.1778766,
                    0.05119922,
                    -0.29639244,
                    -0.13204463,
                    0.019229155,
                    -0.029514017,
                    -0.168669,
                    -0.1870547,
                    0.030565858,
                    -0.3171055,
                    -0.03315684,
                    -0.13100462,
                    -0.025322285,
                    -0.0279369,
                    -0.13280292,
                    -0.071367,
                    -0.17220059,
                    -0.11124851,
                    0.16763113,
                    0.06615649,
                    0.00736435,
                    -0.005432845,
                    0.0028309082,
                    -0.14925274,
                    0.29655105,
                    0.28778684,
                    0.1029157,
                    -0.22709346,
                    0.12467997,
                    -0.26696387,
                    -0.030379076,
                    -0.099905916,
                    0.030037213,
                    -0.19075271,
                    -0.1794763,
                    0.49299556,
                    0.3293071,
                    -0.10895353,
                    0.12664062,
                    -0.092162006,
                    -0.13777371,
                    -0.034772262,
                    -0.05871062,
                    0.026836451,
                    0.02839845,
                    0.13685997,
                    -0.045252584,
                    -0.14783676,
                    -0.1717493,
                    0.05255844,
                    0.3095329,
                    0.10439775,
                    0.28798553,
                    -0.12638304,
                    0.20911206,
                    0.012737125,
                    0.0957955,
                    -0.13159595,
                    0.063837364,
                    0.3295994,
                    -0.13373746,
                    0.29288846,
                    -0.029724248,
                    -0.0043318802,
                    -0.14764512,
                    -0.11397636,
                    -0.37147662,
                    0.34722033,
                    -0.18721312,
                    0.12718824,
                    -0.09654416,
                    -0.0766707,
                    -0.2977832,
                    -0.14676157,
                    0.26951638,
                    -0.082651325,
                    -0.27545825,
                    0.13462485,
                    -0.103285775,
                    0.06710768,
                    0.0015975507,
                    -0.14853989,
                    0.034589108,
                    0.04793806,
                    -0.0091663385,
                    -0.013575685,
                    -0.042185366,
                    -0.084682934,
                    0.058529433,
                    -0.088299036,
                    -0.04100522,
                    -0.1625054,
                    -0.27100295,
                    0.11074404,
                    0.092317745,
                    0.018090539,
                    -0.088989764,
                    -0.22713757,
                    -0.2248414,
                    -0.23065864,
                    -0.12176347,
                    0.4832462,
                    -0.22033268,
                    -0.065273635,
                    0.10010773,
                    -0.029044893,
                    -0.39913887,
                    0.122752,
                    -0.3432,
                    -0.15346213,
                    0.038391877,
                    -0.106522396,
                    -0.24408461,
                    -0.2250687,
                    0.13498728,
                    0.089975856,
                    -0.0010886639,
                    0.14626358,
                    -0.25450054,
                    -0.16206416,
                    -0.35668114,
                    -0.027360622,
                    -0.042533685,
                    -0.028601661,
                    -0.2779112,
                    -0.028973868,
                    -0.26466352,
                    -0.24255127,
                    -0.22748213,
                    0.28742003,
                    0.4260002,
                    0.12781607,
                    0.014489925,
                    0.09618315,
                    -0.13999633,
                    -0.056763954,
                    0.010561457,
                    0.021302488,
                    -0.11718368,
                    -0.13106598,
                    -0.06693572,
                    -0.22939485,
                    0.10453988,
                    -0.03318918,
                    0.29774678,
                    -0.30688155,
                    -0.006827436,
                    -0.023383835,
                    -0.24362536
                  ],
                  "xaxis": "x2",
                  "yaxis": "y2"
                },
                {
                  "alignmentgroup": "True",
                  "bingroup": "x",
                  "hovertemplate": "label=-1<br>dataset=test<br>cosine_similarity_custom=%{x}<br>count=%{y}<extra></extra>",
                  "legendgroup": "-1",
                  "marker": {
                    "color": "#EF553B",
                    "opacity": 0.5,
                    "pattern": {
                      "shape": ""
                    }
                  },
                  "name": "-1",
                  "offsetgroup": "-1",
                  "orientation": "v",
                  "showlegend": false,
                  "type": "histogram",
                  "x": [
                    0.10617652,
                    -0.031975385,
                    -0.10883323,
                    0.32207173,
                    -0.31140262,
                    0.11705714,
                    -0.15120678,
                    -0.13226521,
                    0.32406852,
                    0.028926875,
                    -0.12701692,
                    0.14731352,
                    0.15226293,
                    -0.14873207,
                    0.044569373,
                    -0.195492,
                    0.31632814,
                    0.049260046,
                    -0.1941448,
                    0.12955306,
                    0.39960444,
                    0.10171481,
                    0.01060061,
                    -0.08512529,
                    0.20605738,
                    0.10371626,
                    0.031458076,
                    -0.01223364,
                    -0.22145674,
                    -0.09562346,
                    -0.038915318,
                    -0.036689993,
                    0.12696612,
                    -0.062234133,
                    0.23239751,
                    -0.16191426,
                    -0.112516925,
                    0.26545197,
                    -0.06769511,
                    0.017749624,
                    0.175943,
                    0.17765644,
                    -0.11271724,
                    -0.24490058,
                    0.23518375,
                    0.004909072,
                    -0.05079494,
                    -0.18667297,
                    0.0043339524,
                    -0.16339056,
                    -0.0025770266,
                    -0.032812953,
                    0.34513906,
                    -0.1910095,
                    0.07079941,
                    -0.07229687,
                    -0.09512156,
                    -0.16745691,
                    -0.1844601,
                    -0.101349175,
                    0.41526136,
                    0.81717646,
                    -0.0026841194,
                    -0.09593156,
                    0.0574869,
                    -0.024282647,
                    -0.013999336,
                    -0.40602204,
                    0.569725,
                    -0.053840384,
                    -0.20432736,
                    -0.008172799,
                    -0.00547958,
                    0.180258,
                    -0.12665462,
                    -0.21722375,
                    -0.19613074,
                    -0.023697829,
                    0.24777885,
                    0.1711332,
                    0.537012,
                    -0.087072074,
                    -0.07477147,
                    0.05843446,
                    0.025428122,
                    0.09440827,
                    -0.007145553,
                    -0.10771193,
                    0.25761753,
                    0.044429805,
                    -0.03361663,
                    -0.029891007,
                    -0.06391836,
                    0.1614778,
                    0.0022980184,
                    0.009103167,
                    0.03071209,
                    -0.1727816,
                    -0.20178483,
                    -0.13833278,
                    -0.22644272,
                    -0.23083316,
                    0.03205073,
                    0.33865592,
                    -0.11224474,
                    -0.18638353,
                    -0.15613213,
                    0.13247007,
                    0.091726765,
                    -0.08427659,
                    0.04926182,
                    -0.06531185,
                    0.08664228,
                    0.061082914,
                    0.3546836,
                    0.13567112,
                    -0.12794621,
                    -0.25035465,
                    -0.101458125,
                    -0.07636069,
                    -0.00021772343,
                    -0.06313226,
                    -0.14774837,
                    0.015537703,
                    -0.05826063,
                    0.19063057,
                    -0.18808654,
                    0.09029963,
                    -0.2199365,
                    -0.3122693,
                    0.23579088,
                    0.31103528,
                    -0.041199364,
                    -0.00906378,
                    0.21824978,
                    -0.08886902,
                    0.18546428,
                    0.20851673,
                    -0.15233244,
                    -0.3224695,
                    -0.0325812,
                    -0.04210101,
                    0.16766939,
                    0.26867718,
                    0.111526266,
                    -0.036602326,
                    -0.14308666,
                    0.15467033,
                    -0.24415994,
                    0.24594295,
                    -0.24345881,
                    0.32660398,
                    0.03280199,
                    -0.09476015,
                    -0.29096526,
                    0.032068066,
                    0.17987496,
                    0.03262933,
                    0.20223673,
                    -0.120644696,
                    0.29335102,
                    -0.14304209,
                    -0.021362929,
                    0.06404272,
                    0.05794023,
                    0.0831995,
                    0.0718886,
                    0.2677285,
                    0.35200343,
                    -0.035561264,
                    0.18149698,
                    0.23688413,
                    -0.1472639,
                    0.057499476,
                    0.08719631,
                    -0.09575894,
                    -0.08939485,
                    -0.14045444,
                    0.026247796,
                    -0.11188459,
                    -0.07919278,
                    0.31483328,
                    0.039055776,
                    -0.12993088,
                    -0.1622256,
                    0.31707028,
                    -0.3364733,
                    0.24035016,
                    -0.027344713,
                    0.014690834,
                    -0.16855082,
                    -0.06932548,
                    -0.022266265,
                    -0.22546014,
                    0.008567488,
                    -0.102743,
                    0.13508673,
                    0.15486145,
                    -0.2813347,
                    0.1068097,
                    0.13032803,
                    -0.019828802,
                    0.11795447,
                    0.4221859,
                    0.000759042,
                    0.09165115,
                    -0.03338406,
                    0.08470532,
                    -0.1508812,
                    0.035786882,
                    0.030124856,
                    0.19506164,
                    0.26077172,
                    -0.124081574,
                    -0.31896383,
                    0.10112047,
                    -0.057609923,
                    -0.00080296077,
                    0.23779202,
                    -0.23625362,
                    -0.08160995,
                    0.3032903,
                    -0.08120042,
                    0.16656089,
                    0.10409328,
                    0.18034437,
                    0.046120126,
                    0.1544962,
                    -0.09288298,
                    0.2510819,
                    0.29124844,
                    0.17870253,
                    -0.005379772,
                    -0.30627653,
                    -0.06947586,
                    0.15925092,
                    -0.13554731,
                    0.043542545,
                    0.17783535,
                    -0.20150776,
                    -0.11547584,
                    -0.16441183,
                    0.4270623,
                    0.22010702,
                    0.25666305,
                    -0.04944679,
                    0.057062034,
                    0.0508959,
                    -0.16835836,
                    -0.19262494,
                    -0.1199862,
                    0.15262641,
                    -0.046507467,
                    0.034018174,
                    0.013183228,
                    0.12422307,
                    -0.047742013,
                    0.008541771,
                    0.28857812,
                    -0.062908895,
                    0.08314598,
                    -0.35084343,
                    0.061281245,
                    0.0785819,
                    -0.08573331,
                    -0.13729927,
                    -0.025910659,
                    0.11283098,
                    0.05220833,
                    -0.012076858,
                    -0.022399202,
                    0.27282238,
                    0.20037143,
                    0.49455974,
                    -0.25124055,
                    -0.15144002,
                    -0.17023633,
                    0.17223814,
                    -0.2117701,
                    -0.07220247,
                    0.031898964,
                    0.3294976,
                    -0.11548876,
                    0.11725828,
                    -0.20589651,
                    -0.20220196,
                    -0.120919146,
                    -0.17677295,
                    -0.07525886,
                    0.109793626,
                    -0.034417115,
                    0.16536875,
                    -0.023174297,
                    0.5393097,
                    -0.0039643934,
                    -0.07138218,
                    0.28882954,
                    -0.13014461,
                    -0.23138109,
                    -0.1327854,
                    0.20499766,
                    -0.055300497,
                    -0.0646167,
                    0.009579126,
                    0.041330144,
                    0.0024610406,
                    0.12112454,
                    0.07543635,
                    0.1213533,
                    0.08904015,
                    -0.08634728,
                    -0.06367639,
                    -0.06046631,
                    -0.19301783,
                    -0.14426412,
                    0.07088253,
                    0.38884938,
                    -0.06855903,
                    0.1762231,
                    -0.10999228,
                    0.14951923,
                    -0.29029095,
                    0.12920266,
                    0.035691362,
                    -0.14431979,
                    0.114437565,
                    -0.04307397,
                    -0.1095031,
                    -0.24052213,
                    -0.056761023,
                    -0.2482388,
                    -0.022682087,
                    0.02143445,
                    0.1074058,
                    0.3866141
                  ],
                  "xaxis": "x",
                  "yaxis": "y"
                }
              ],
              "layout": {
                "annotations": [
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "dataset=test",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.2425,
                    "yanchor": "middle",
                    "yref": "paper"
                  },
                  {
                    "font": {},
                    "showarrow": false,
                    "text": "dataset=train",
                    "textangle": 90,
                    "x": 0.98,
                    "xanchor": "left",
                    "xref": "paper",
                    "y": 0.7575000000000001,
                    "yanchor": "middle",
                    "yref": "paper"
                  }
                ],
                "barmode": "overlay",
                "legend": {
                  "title": {
                    "text": "label"
                  },
                  "tracegroupgap": 0
                },
                "margin": {
                  "t": 60
                },
                "template": {
                  "data": {
                    "bar": [
                      {
                        "error_x": {
                          "color": "#2a3f5f"
                        },
                        "error_y": {
                          "color": "#2a3f5f"
                        },
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "bar"
                      }
                    ],
                    "barpolar": [
                      {
                        "marker": {
                          "line": {
                            "color": "#E5ECF6",
                            "width": 0.5
                          },
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "barpolar"
                      }
                    ],
                    "carpet": [
                      {
                        "aaxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "baxis": {
                          "endlinecolor": "#2a3f5f",
                          "gridcolor": "white",
                          "linecolor": "white",
                          "minorgridcolor": "white",
                          "startlinecolor": "#2a3f5f"
                        },
                        "type": "carpet"
                      }
                    ],
                    "choropleth": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "choropleth"
                      }
                    ],
                    "contour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "contour"
                      }
                    ],
                    "contourcarpet": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "contourcarpet"
                      }
                    ],
                    "heatmap": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmap"
                      }
                    ],
                    "heatmapgl": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "heatmapgl"
                      }
                    ],
                    "histogram": [
                      {
                        "marker": {
                          "pattern": {
                            "fillmode": "overlay",
                            "size": 10,
                            "solidity": 0.2
                          }
                        },
                        "type": "histogram"
                      }
                    ],
                    "histogram2d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2d"
                      }
                    ],
                    "histogram2dcontour": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "histogram2dcontour"
                      }
                    ],
                    "mesh3d": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "type": "mesh3d"
                      }
                    ],
                    "parcoords": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "parcoords"
                      }
                    ],
                    "pie": [
                      {
                        "automargin": true,
                        "type": "pie"
                      }
                    ],
                    "scatter": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter"
                      }
                    ],
                    "scatter3d": [
                      {
                        "line": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatter3d"
                      }
                    ],
                    "scattercarpet": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattercarpet"
                      }
                    ],
                    "scattergeo": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergeo"
                      }
                    ],
                    "scattergl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattergl"
                      }
                    ],
                    "scattermapbox": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scattermapbox"
                      }
                    ],
                    "scatterpolar": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolar"
                      }
                    ],
                    "scatterpolargl": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterpolargl"
                      }
                    ],
                    "scatterternary": [
                      {
                        "marker": {
                          "colorbar": {
                            "outlinewidth": 0,
                            "ticks": ""
                          }
                        },
                        "type": "scatterternary"
                      }
                    ],
                    "surface": [
                      {
                        "colorbar": {
                          "outlinewidth": 0,
                          "ticks": ""
                        },
                        "colorscale": [
                          [
                            0,
                            "#0d0887"
                          ],
                          [
                            0.1111111111111111,
                            "#46039f"
                          ],
                          [
                            0.2222222222222222,
                            "#7201a8"
                          ],
                          [
                            0.3333333333333333,
                            "#9c179e"
                          ],
                          [
                            0.4444444444444444,
                            "#bd3786"
                          ],
                          [
                            0.5555555555555556,
                            "#d8576b"
                          ],
                          [
                            0.6666666666666666,
                            "#ed7953"
                          ],
                          [
                            0.7777777777777778,
                            "#fb9f3a"
                          ],
                          [
                            0.8888888888888888,
                            "#fdca26"
                          ],
                          [
                            1,
                            "#f0f921"
                          ]
                        ],
                        "type": "surface"
                      }
                    ],
                    "table": [
                      {
                        "cells": {
                          "fill": {
                            "color": "#EBF0F8"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "header": {
                          "fill": {
                            "color": "#C8D4E3"
                          },
                          "line": {
                            "color": "white"
                          }
                        },
                        "type": "table"
                      }
                    ]
                  },
                  "layout": {
                    "annotationdefaults": {
                      "arrowcolor": "#2a3f5f",
                      "arrowhead": 0,
                      "arrowwidth": 1
                    },
                    "autotypenumbers": "strict",
                    "coloraxis": {
                      "colorbar": {
                        "outlinewidth": 0,
                        "ticks": ""
                      }
                    },
                    "colorscale": {
                      "diverging": [
                        [
                          0,
                          "#8e0152"
                        ],
                        [
                          0.1,
                          "#c51b7d"
                        ],
                        [
                          0.2,
                          "#de77ae"
                        ],
                        [
                          0.3,
                          "#f1b6da"
                        ],
                        [
                          0.4,
                          "#fde0ef"
                        ],
                        [
                          0.5,
                          "#f7f7f7"
                        ],
                        [
                          0.6,
                          "#e6f5d0"
                        ],
                        [
                          0.7,
                          "#b8e186"
                        ],
                        [
                          0.8,
                          "#7fbc41"
                        ],
                        [
                          0.9,
                          "#4d9221"
                        ],
                        [
                          1,
                          "#276419"
                        ]
                      ],
                      "sequential": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ],
                      "sequentialminus": [
                        [
                          0,
                          "#0d0887"
                        ],
                        [
                          0.1111111111111111,
                          "#46039f"
                        ],
                        [
                          0.2222222222222222,
                          "#7201a8"
                        ],
                        [
                          0.3333333333333333,
                          "#9c179e"
                        ],
                        [
                          0.4444444444444444,
                          "#bd3786"
                        ],
                        [
                          0.5555555555555556,
                          "#d8576b"
                        ],
                        [
                          0.6666666666666666,
                          "#ed7953"
                        ],
                        [
                          0.7777777777777778,
                          "#fb9f3a"
                        ],
                        [
                          0.8888888888888888,
                          "#fdca26"
                        ],
                        [
                          1,
                          "#f0f921"
                        ]
                      ]
                    },
                    "colorway": [
                      "#636efa",
                      "#EF553B",
                      "#00cc96",
                      "#ab63fa",
                      "#FFA15A",
                      "#19d3f3",
                      "#FF6692",
                      "#B6E880",
                      "#FF97FF",
                      "#FECB52"
                    ],
                    "font": {
                      "color": "#2a3f5f"
                    },
                    "geo": {
                      "bgcolor": "white",
                      "lakecolor": "white",
                      "landcolor": "#E5ECF6",
                      "showlakes": true,
                      "showland": true,
                      "subunitcolor": "white"
                    },
                    "hoverlabel": {
                      "align": "left"
                    },
                    "hovermode": "closest",
                    "mapbox": {
                      "style": "light"
                    },
                    "paper_bgcolor": "white",
                    "plot_bgcolor": "#E5ECF6",
                    "polar": {
                      "angularaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "radialaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "scene": {
                      "xaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "yaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      },
                      "zaxis": {
                        "backgroundcolor": "#E5ECF6",
                        "gridcolor": "white",
                        "gridwidth": 2,
                        "linecolor": "white",
                        "showbackground": true,
                        "ticks": "",
                        "zerolinecolor": "white"
                      }
                    },
                    "shapedefaults": {
                      "line": {
                        "color": "#2a3f5f"
                      }
                    },
                    "ternary": {
                      "aaxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "baxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      },
                      "bgcolor": "#E5ECF6",
                      "caxis": {
                        "gridcolor": "white",
                        "linecolor": "white",
                        "ticks": ""
                      }
                    },
                    "title": {
                      "x": 0.05
                    },
                    "xaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    },
                    "yaxis": {
                      "automargin": true,
                      "gridcolor": "white",
                      "linecolor": "white",
                      "ticks": "",
                      "title": {
                        "standoff": 15
                      },
                      "zerolinecolor": "white",
                      "zerolinewidth": 2
                    }
                  }
                },
                "width": 500,
                "xaxis": {
                  "anchor": "y",
                  "domain": [
                    0,
                    0.98
                  ],
                  "title": {
                    "text": "cosine_similarity_custom"
                  }
                },
                "xaxis2": {
                  "anchor": "y2",
                  "domain": [
                    0,
                    0.98
                  ],
                  "matches": "x",
                  "showticklabels": false
                },
                "yaxis": {
                  "anchor": "x",
                  "domain": [
                    0,
                    0.485
                  ],
                  "title": {
                    "text": "count"
                  }
                },
                "yaxis2": {
                  "anchor": "x2",
                  "domain": [
                    0.515,
                    1
                  ],
                  "matches": "y",
                  "title": {
                    "text": "count"
                  }
                }
              }
            }
          },
          "metadata": {},
          "output_type": "display_data"
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "Test accuracy after customization: 93.6% ± 1.9%\n"
          ]
        }
      ],
      "source": [
        "# plot similarity distribution BEFORE customization\n",
        "px.histogram(\n",
        "    df,\n",
        "    x=\"cosine_similarity\",\n",
        "    color=\"label\",\n",
        "    barmode=\"overlay\",\n",
        "    width=500,\n",
        "    facet_row=\"dataset\",\n",
        ").show()\n",
        "\n",
        "test_df = df[df[\"dataset\"] == \"test\"]\n",
        "a, se = accuracy_and_se(test_df[\"cosine_similarity\"], test_df[\"label\"])\n",
        "print(f\"Test accuracy: {a:0.1%} ± {1.96 * se:0.1%}\")\n",
        "\n",
        "# plot similarity distribution AFTER customization\n",
        "px.histogram(\n",
        "    df,\n",
        "    x=\"cosine_similarity_custom\",\n",
        "    color=\"label\",\n",
        "    barmode=\"overlay\",\n",
        "    width=500,\n",
        "    facet_row=\"dataset\",\n",
        ").show()\n",
        "\n",
        "a, se = accuracy_and_se(test_df[\"cosine_similarity_custom\"], test_df[\"label\"])\n",
        "print(f\"Test accuracy after customization: {a:0.1%} ± {1.96 * se:0.1%}\")\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "id": "XO7iqiVjpgkT",
        "outputId": "a100a9e0-d5aa-46ab-b8a7-4ec6f7bd1cec"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "array([[-1.2566795e+00, -1.5297449e+00, -1.3271648e-01, ...,\n",
              "        -1.2859761e+00, -5.3254390e-01,  4.8364732e-01],\n",
              "       [-1.4826347e+00,  9.2656955e-02, -4.2437232e-01, ...,\n",
              "         1.1872858e+00, -1.0831847e+00, -1.0683593e+00],\n",
              "       [-2.2029283e+00, -1.9703420e+00,  3.1125939e-01, ...,\n",
              "         2.2947595e+00,  5.5780332e-03, -6.0171342e-01],\n",
              "       ...,\n",
              "       [-1.1019799e-01,  1.3599515e+00, -4.7677776e-01, ...,\n",
              "         6.5626711e-01,  7.2359240e-01,  3.0733588e+00],\n",
              "       [ 1.6624762e-03,  4.2648423e-01, -1.1380885e+00, ...,\n",
              "         8.7202555e-01,  9.3173909e-01, -1.6760436e+00],\n",
              "       [ 7.7449006e-01,  4.9213606e-01,  3.5407653e-01, ...,\n",
              "         1.3460466e+00, -1.9509128e-01,  7.7514690e-01]], dtype=float32)"
            ]
          },
          "execution_count": 15,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "best_matrix  # the learned matrix; multiply your embeddings by it to get custom embeddings\n"
      ]
    }
  ],
  "metadata": {
    "colab": {
      "name": "customized_embeddings_example_with_synthetic_negatives.ipynb",
      "provenance": []
    },
    "interpreter": {
      "hash": "365536dcbde60510dc9073d6b991cd35db2d9bac356a11f5b64279a5e6708b97"
    },
    "kernelspec": {
      "display_name": "Python 3.9.9 ('openai')",
      "language": "python",
      "name": "python3"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "ipython",
        "version": 3
      },
      "file_extension": ".py",
      "mimetype": "text/x-python",
      "name": "python",
      "nbconvert_exporter": "python",
      "pygments_lexer": "ipython3",
      "version": "3.9.9"
    },
    "orig_nbformat": 4
  },
  "nbformat": 4,
  "nbformat_minor": 0
}
